numrs/llo/
elementwise.rs

//! Elementwise LLO operations and strategy hints.
//!
//! This module describes elementwise ops (add, mul, ...) and the possible
//! execution strategies for them (vectorized, scalar loop, GPU kernel, etc.).
4
5use serde::{Serialize, Deserialize};
6
7/// Binary elementwise kinds (add, mul, sub, div, pow, ...)
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9pub enum ElementwiseBinaryKind {
10    Add,
11    Mul,
12    Sub,
13    Div,
14    Pow,
15}
16
17/// Unary elementwise kinds (sqrt, sin, cos, exp, log, ...)
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19pub enum ElementwiseUnaryKind {
20    Sqrt,
21    Abs,
22    Neg,
23    Exp,
24    Log,
25    Sin,
26    Cos,
27    Tan,
28    Asin,
29    Acos,
30    Atan,
31    Relu,
32    LeakyRelu,
33}
34
35/// Flattened ElementwiseKind kept for backward compatibility; prefer
36/// using `ElementwiseBinaryKind` / `ElementwiseUnaryKind` when possible.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38pub enum ElementwiseKind {
39    Add,
40    Mul,
41    Sub,
42    Div,
43    Pow,
44    /// Unary sqrt
45    Sqrt,
46    /// Trigonometric
47    Sin,
48    Cos,
49    Tan,
50    /// Inverse trigonometric
51    Asin,
52    Acos,
53    Atan,
54    /// Unary absolute value
55    Abs,
56    /// Unary negation
57    Neg,
58    /// Exponential
59    Exp,
60    /// Logarithm
61    Log,
62    /// Activation functions
63    Relu,
64    LeakyRelu,
65    Sigmoid,
66    Tanh,
67    Softplus,
68}
69
70/// Execution strategy selection enum. Backends may choose different variants.
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
72pub enum ElementwiseStrategy {
73    /// Let the runtime resolver pick the best concrete implementation.
74    Default,
75    Scalar,
76    Vectorized,
77    GpuKernel,
78}
79
80// Conversion helpers
81impl From<ElementwiseBinaryKind> for ElementwiseKind {
82    fn from(b: ElementwiseBinaryKind) -> Self {
83        match b {
84            ElementwiseBinaryKind::Add => ElementwiseKind::Add,
85            ElementwiseBinaryKind::Mul => ElementwiseKind::Mul,
86            ElementwiseBinaryKind::Sub => ElementwiseKind::Sub,
87            ElementwiseBinaryKind::Div => ElementwiseKind::Div,
88            ElementwiseBinaryKind::Pow => ElementwiseKind::Pow,
89        }
90    }
91}
92
93impl From<ElementwiseUnaryKind> for ElementwiseKind {
94    fn from(u: ElementwiseUnaryKind) -> Self {
95        match u {
96            ElementwiseUnaryKind::Sqrt => ElementwiseKind::Sqrt,
97            ElementwiseUnaryKind::Sin => ElementwiseKind::Sin,
98            ElementwiseUnaryKind::Cos => ElementwiseKind::Cos,
99            ElementwiseUnaryKind::Abs => ElementwiseKind::Abs,
100            ElementwiseUnaryKind::Neg => ElementwiseKind::Neg,
101            ElementwiseUnaryKind::Exp => ElementwiseKind::Exp,
102            ElementwiseUnaryKind::Log => ElementwiseKind::Log,
103            ElementwiseUnaryKind::Tan => ElementwiseKind::Tan,
104            ElementwiseUnaryKind::Asin => ElementwiseKind::Asin,
105            ElementwiseUnaryKind::Acos => ElementwiseKind::Acos,
106            ElementwiseUnaryKind::Atan => ElementwiseKind::Atan,
107            ElementwiseUnaryKind::Relu => ElementwiseKind::Relu,
108            ElementwiseUnaryKind::LeakyRelu => ElementwiseKind::LeakyRelu,
109        }
110    }
111}