//! Shrink activation layer.
//!
//! Source path: `burn_nn/activation/shrink.rs`

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Module;
5use burn::module::{Content, DisplaySettings, ModuleDisplay};
6use burn::tensor::Tensor;
7use burn::tensor::activation::shrink;
8use burn::tensor::backend::Backend;
9
/// Shrink layer.
///
/// Applies the Shrink function element-wise:
/// `shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`
///
/// Should be created with [ShrinkConfig](ShrinkConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Shrink {
    /// The lambda (threshold) value for the Shrink formulation: inputs whose
    /// magnitude does not exceed `lambda` are mapped to zero.
    pub lambda: f64,
    /// The bias value subtracted from (or added to) inputs outside the
    /// `[-lambda, lambda]` dead zone.
    // Usually bias = lambda, but need this to handle onnx spec https://onnx.ai/onnx/operators/onnx__Shrink.html
    pub bias: f64,
}
25
/// Configuration to create a [Shrink](Shrink) layer using the [init function](ShrinkConfig::init).
///
/// The defaults (`lambda = bias = 0.5`) match the ONNX Shrink operator defaults.
#[derive(Config, Debug)]
pub struct ShrinkConfig {
    /// The lambda value for the Shrink formulation. Default is 0.5
    #[config(default = "0.5")]
    pub lambda: f64,
    /// The bias value for the Shrink formulation. Default is 0.5.
    #[config(default = "0.5")]
    pub bias: f64,
}
36
37impl ShrinkConfig {
38    /// Initialize a new [Shrink](Shrink) Layer
39    pub fn init(&self) -> Shrink {
40        Shrink {
41            lambda: self.lambda,
42            bias: self.bias,
43        }
44    }
45}
46
47impl ModuleDisplay for Shrink {
48    fn custom_settings(&self) -> Option<DisplaySettings> {
49        DisplaySettings::new()
50            .with_new_line_after_attribute(false)
51            .optional()
52    }
53
54    fn custom_content(&self, content: Content) -> Option<Content> {
55        content
56            .add("lambda", &self.lambda)
57            .add("bias", &self.bias)
58            .optional()
59    }
60}
61
62impl Shrink {
63    /// Forward pass for the Shrink layer.
64    ///
65    /// See [shrink](burn::tensor::activation::shrink) for more information.
66    ///
67    /// # Shapes
68    /// - input: `[..., any]`
69    /// - output: `[..., any]`
70    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
71        shrink(input, self.lambda, self.bias)
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    #[test]
    fn test_shrink_forward() {
        // Default config: lambda = bias = 0.5.
        let device = Default::default();
        let layer: Shrink = ShrinkConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);
        let output = layer.forward(input);
        // Values inside [-0.5, 0.5] collapse to zero; the rest shift toward zero by bias.
        assert_eq!(
            output.into_data(),
            TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]])
        );
    }

    #[test]
    fn test_shrink_with_lambda_and_bias() {
        // Non-default parameters: lambda = 0.25, bias = 0.125.
        let device = Default::default();
        let layer: Shrink = ShrinkConfig::new()
            .with_lambda(0.25)
            .with_bias(0.125)
            .init();
        let input =
            Tensor::<TestBackend, 2>::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device);
        let output = layer.forward(input);
        assert_eq!(
            output.into_data(),
            TensorData::from([[0.0_f32, 0.0, -0.375], [0.625, 0.0, 0.0]])
        );
    }

    #[test]
    fn display() {
        // The custom ModuleDisplay impl keeps the attributes on one line.
        let layer = ShrinkConfig::new().init();
        let rendered = alloc::format!("{layer}");
        assert_eq!(rendered, "Shrink {lambda: 0.5, bias: 0.5}");
    }
}