// lumen_core/grad/op.rs

use crate::{FloatDType, IntTensor, Tensor, WithDType};

3#[derive(Clone)]
4pub enum Op<T: FloatDType> {
5    Binary(Tensor<T>, Tensor<T>, BinaryOp),
6    BinaryScalarRhs(Tensor<T>, T, BinaryOp),
7    BinaryScalarLhs(T, Tensor<T>, BinaryOp),
8    Unary(Tensor<T>, UnaryOp<T>),
9    Pow(Tensor<T>, T),
10    Reduce(Tensor<T>, ReduceOp, Vec<usize>),
11    Matmul(Tensor<T>, Tensor<T>),
12    Broadcast(Tensor<T>),
13    Narrow(Tensor<T>, usize, usize, usize),
14    Slice(Tensor<T>, usize, usize, usize, usize),
15    IndexSelect(Tensor<T>, IntTensor, usize),
16    IndexAdd(Tensor<T>, IntTensor, Tensor<T>, usize),
17    ScatterAdd(Tensor<T>, IntTensor, Tensor<T>, usize),
18    Gather(Tensor<T>, IntTensor, usize),
19    Reshape(Tensor<T>),
20    Transpose(Tensor<T>, usize, usize),
21    Permute(Tensor<T>, Vec<usize>),
22    Cat(Vec<Tensor<T>>, usize),
23    IfElse(Tensor<bool>, Option<Tensor<T>>, Option<Tensor<T>>),
24    Copy(Tensor<T>),
25}
26
/// Comparison operator between two values.
///
/// Derives `Debug` like the other operator enums in this module
/// (`ReduceOp`, `BinaryOp`, `UnaryOp`), so values can be printed with `{:?}`
/// and used with `assert_eq!` / `assert_ne!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// Equal (`==`).
    Eq,
    /// Not equal (`!=`).
    Ne,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than or equal (`>=`).
    Ge,
    /// Strictly less than (`<`).
    Lt,
    /// Strictly greater than (`>`).
    Gt,
}

/// Reduction kind carried by `Op::Reduce`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReduceOp {
    /// Sum of the reduced elements.
    Sum,
    /// Minimum of the reduced elements.
    Min,
    /// Maximum of the reduced elements.
    Max,
    /// Arithmetic mean of the reduced elements.
    Mean,
}

/// Binary operator carried by `Op::Binary`, `Op::BinaryScalarRhs` and
/// `Op::BinaryScalarLhs`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Multiplication.
    Mul,
    /// Subtraction.
    Sub,
    /// Division.
    Div,
    /// Larger of the two operands.
    Maximum,
    /// Smaller of the two operands.
    Minimum,
}

55#[derive(Debug, Clone, Copy, PartialEq, Eq)]
56pub enum UnaryOp<T: WithDType> {
57    Exp,
58    Ln,
59
60    Sin,
61    Cos,
62    Tanh,
63
64    Abs,
65    Neg,
66    Sqr,
67    Sqrt,
68
69    Recip,
70    Gelu,
71    GeluErf,
72    Erf,
73    Relu,
74    LeakyRelu(T),
75    Silu,
76    Sigmoid,
77
78    Floor,
79    Ceil,
80    Round,
81    Sign,
82}