1use crate::{FloatDType, IntTensor, Tensor, WithDType};
2
3#[derive(Clone)]
4pub enum Op<T: FloatDType> {
5 Binary(Tensor<T>, Tensor<T>, BinaryOp),
6 BinaryScalarRhs(Tensor<T>, T, BinaryOp),
7 BinaryScalarLhs(T, Tensor<T>, BinaryOp),
8 Unary(Tensor<T>, UnaryOp<T>),
9 Pow(Tensor<T>, T),
10 Reduce(Tensor<T>, ReduceOp, Vec<usize>),
11 Matmul(Tensor<T>, Tensor<T>),
12 Broadcast(Tensor<T>),
13 Narrow(Tensor<T>, usize, usize, usize),
14 Slice(Tensor<T>, usize, usize, usize, usize),
15 IndexSelect(Tensor<T>, IntTensor, usize),
16 IndexAdd(Tensor<T>, IntTensor, Tensor<T>, usize),
17 ScatterAdd(Tensor<T>, IntTensor, Tensor<T>, usize),
18 Gather(Tensor<T>, IntTensor, usize),
19 Reshape(Tensor<T>),
20 Transpose(Tensor<T>, usize, usize),
21 Permute(Tensor<T>, Vec<usize>),
22 Cat(Vec<Tensor<T>>, usize),
23 IfElse(Tensor<bool>, Option<Tensor<T>>, Option<Tensor<T>>),
24 Copy(Tensor<T>),
25}
26
/// Comparison operator kinds.
///
/// Derives `Debug` like the other operator enums in this module (`ReduceOp`,
/// `BinaryOp`, `UnaryOp`) so values can be formatted with `{:?}`, used in
/// `assert_eq!`, and embedded in types that themselves derive `Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// Equal (`==`).
    Eq,
    /// Not equal (`!=`).
    Ne,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than or equal (`>=`).
    Ge,
    /// Less than (`<`).
    Lt,
    /// Greater than (`>`).
    Gt,
}
36
/// Reduction kinds, carried by [`Op::Reduce`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ReduceOp {
    /// Sum of the reduced elements.
    Sum,
    /// Minimum of the reduced elements.
    Min,
    /// Maximum of the reduced elements.
    Max,
    /// Mean of the reduced elements.
    Mean,
}
44
/// Binary operator kinds, shared by [`Op::Binary`], [`Op::BinaryScalarRhs`]
/// and [`Op::BinaryScalarLhs`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Multiplication.
    Mul,
    /// Subtraction.
    Sub,
    /// Division.
    Div,
    /// Larger of the two operands.
    Maximum,
    /// Smaller of the two operands.
    Minimum,
}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq)]
56pub enum UnaryOp<T: WithDType> {
57 Exp,
58 Ln,
59
60 Sin,
61 Cos,
62 Tanh,
63
64 Abs,
65 Neg,
66 Sqr,
67 Sqrt,
68
69 Recip,
70 Gelu,
71 GeluErf,
72 Erf,
73 Relu,
74 LeakyRelu(T),
75 Silu,
76 Sigmoid,
77
78 Floor,
79 Ceil,
80 Round,
81 Sign,
82}