//! Activation layers: ReLU, LeakyReLU, sigmoid, tanh, GELU, SiLU, Mish, PReLU.

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use crate::ModelError;

6/// Stateless ReLU layer.
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
8pub struct ReLULayer;
9
10impl ReLULayer {
11    pub fn new() -> Self {
12        Self
13    }
14
15    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
16        graph.relu(input).map_err(Into::into)
17    }
18}
19
20/// Stateless LeakyReLU layer with configurable negative slope.
21#[derive(Debug, Clone, Copy, PartialEq)]
22pub struct LeakyReLULayer {
23    negative_slope: f32,
24}
25
26impl LeakyReLULayer {
27    pub fn new(negative_slope: f32) -> Result<Self, ModelError> {
28        if !negative_slope.is_finite() || negative_slope < 0.0 {
29            return Err(ModelError::InvalidLeakyReluSlope { negative_slope });
30        }
31        Ok(Self { negative_slope })
32    }
33
34    pub fn negative_slope(&self) -> f32 {
35        self.negative_slope
36    }
37
38    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
39        let positive = graph.relu(input)?;
40        let zero = graph.constant(Tensor::scalar(0.0));
41        let neg_input = graph.sub(zero, input)?;
42        let negative_magnitude = graph.relu(neg_input)?;
43        let slope = graph.constant(Tensor::scalar(self.negative_slope));
44        let scaled_negative = graph.mul(negative_magnitude, slope)?;
45        graph.sub(positive, scaled_negative).map_err(Into::into)
46    }
47}
48
49/// Stateless sigmoid activation layer.
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
51pub struct SigmoidLayer;
52
53impl SigmoidLayer {
54    pub fn new() -> Self {
55        Self
56    }
57
58    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
59        graph.sigmoid(input).map_err(Into::into)
60    }
61}
62
63/// Stateless tanh activation layer.
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
65pub struct TanhLayer;
66
67impl TanhLayer {
68    pub fn new() -> Self {
69        Self
70    }
71
72    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
73        graph.tanh(input).map_err(Into::into)
74    }
75}
76
77/// GELU activation layer.
78#[derive(Debug, Clone)]
79pub struct GELULayer;
80
81impl Default for GELULayer {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl GELULayer {
88    pub fn new() -> Self {
89        Self
90    }
91
92    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
93        graph.gelu(input).map_err(Into::into)
94    }
95
96    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
97        Ok(yscv_kernels::gelu(input))
98    }
99}
100
101/// SiLU (Swish) activation layer.
102#[derive(Debug, Clone)]
103pub struct SiLULayer;
104
105impl Default for SiLULayer {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl SiLULayer {
112    pub fn new() -> Self {
113        Self
114    }
115
116    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
117        graph.silu(input).map_err(Into::into)
118    }
119
120    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
121        Ok(yscv_kernels::silu(input))
122    }
123}
124
125/// Mish activation layer.
126#[derive(Debug, Clone)]
127pub struct MishLayer;
128
129impl Default for MishLayer {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl MishLayer {
136    pub fn new() -> Self {
137        Self
138    }
139
140    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
141        graph.mish(input).map_err(Into::into)
142    }
143
144    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
145        Ok(yscv_kernels::mish(input))
146    }
147}
148
149/// PReLU activation layer.
150/// Uses per-channel or single alpha for the negative slope.
151#[derive(Debug, Clone)]
152pub struct PReLULayer {
153    alpha: Vec<f32>,
154    alpha_node: Option<NodeId>,
155}
156
157impl PReLULayer {
158    pub fn new(alpha: Vec<f32>) -> Self {
159        Self {
160            alpha,
161            alpha_node: None,
162        }
163    }
164
165    pub fn alpha(&self) -> &[f32] {
166        &self.alpha
167    }
168
169    pub fn alpha_node(&self) -> Option<NodeId> {
170        self.alpha_node
171    }
172
173    pub fn register_params(&mut self, graph: &mut Graph) {
174        self.alpha_node = Some(
175            graph.variable(
176                Tensor::from_vec(vec![self.alpha.len()], self.alpha.clone())
177                    .expect("shape matches data"),
178            ),
179        );
180    }
181
182    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
183        if let Some(a_id) = self.alpha_node {
184            self.alpha = graph.value(a_id)?.data().to_vec();
185        }
186        Ok(())
187    }
188
189    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
190        let a_id = self
191            .alpha_node
192            .ok_or(ModelError::ParamsNotRegistered { layer: "PReLU" })?;
193        graph.prelu(input, a_id).map_err(Into::into)
194    }
195
196    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
197        let data = input.data();
198        let out: Vec<f32> = if self.alpha.len() == 1 {
199            let a = self.alpha[0];
200            data.iter()
201                .map(|&x| if x > 0.0 { x } else { a * x })
202                .collect()
203        } else {
204            let shape = input.shape();
205            let channels = if shape.len() >= 2 { shape[1] } else { 1 };
206            let spatial: usize = shape[2..].iter().product();
207            let mut result = data.to_vec();
208            for (i, v) in result.iter_mut().enumerate() {
209                let c = (i / spatial) % channels;
210                let a = self.alpha.get(c).copied().unwrap_or(0.01);
211                if *v < 0.0 {
212                    *v *= a;
213                }
214            }
215            result
216        };
217        Tensor::from_vec(input.shape().to_vec(), out).map_err(ModelError::Tensor)
218    }
219}