burn_nn/activation/
shrink.rs1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Module;
5use burn::module::{Content, DisplaySettings, ModuleDisplay};
6use burn::tensor::Tensor;
7use burn::tensor::activation::shrink;
8use burn::tensor::backend::Backend;
9
/// Applies the Shrink activation element-wise.
///
/// For each element `x` of the input, the output is:
///
/// - `x - bias` if `x > lambda`
/// - `x + bias` if `x < -lambda`
/// - `0` otherwise (i.e. when `-lambda <= x <= lambda`)
///
/// Create it with [`ShrinkConfig::init`] and apply it with [`Shrink::forward`].
#[derive(Module, Clone, Debug)]
// Display output is customised through this type's `ModuleDisplay` implementation.
#[module(custom_display)]
pub struct Shrink {
    /// Threshold: elements with `|x| <= lambda` are mapped to zero.
    pub lambda: f64,
    /// Offset subtracted from elements above `lambda` and added to elements below `-lambda`.
    pub bias: f64,
}
25
/// Configuration used to create a [`Shrink`] layer.
#[derive(Config, Debug)]
pub struct ShrinkConfig {
    /// Threshold: elements with `|x| <= lambda` are mapped to zero. Default is 0.5.
    #[config(default = "0.5")]
    pub lambda: f64,
    /// Offset applied to elements outside `[-lambda, lambda]`. Default is 0.5.
    #[config(default = "0.5")]
    pub bias: f64,
}
36
37impl ShrinkConfig {
38 pub fn init(&self) -> Shrink {
40 Shrink {
41 lambda: self.lambda,
42 bias: self.bias,
43 }
44 }
45}
46
47impl ModuleDisplay for Shrink {
48 fn custom_settings(&self) -> Option<DisplaySettings> {
49 DisplaySettings::new()
50 .with_new_line_after_attribute(false)
51 .optional()
52 }
53
54 fn custom_content(&self, content: Content) -> Option<Content> {
55 content
56 .add("lambda", &self.lambda)
57 .add("bias", &self.bias)
58 .optional()
59 }
60}
61
62impl Shrink {
63 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
71 shrink(input, self.lambda, self.bias)
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    /// Feeds a 2x3 `input` through `layer` and asserts the result equals `expected` exactly.
    ///
    /// Inputs are `f64` to match how untyped float literals are inferred when passed to
    /// `from_data`; the expected values are `f32`.
    fn check_forward(layer: &Shrink, input: [[f64; 3]; 2], expected: [[f32; 3]; 2]) {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2>::from_data(input, &device);
        let actual = layer.forward(tensor).into_data();
        assert_eq!(actual, TensorData::from(expected));
    }

    #[test]
    fn test_shrink_forward() {
        let layer: Shrink = ShrinkConfig::new().init();
        check_forward(
            &layer,
            [[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]],
            [[0.0, 0.0, -0.5], [7.5, 0.0, 0.0]],
        );
    }

    #[test]
    fn test_shrink_with_lambda_and_bias() {
        let layer: Shrink = ShrinkConfig::new()
            .with_lambda(0.25)
            .with_bias(0.125)
            .init();
        check_forward(
            &layer,
            [[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]],
            [[0.0, 0.0, -0.375], [0.625, 0.0, 0.0]],
        );
    }

    #[test]
    fn display() {
        let layer = ShrinkConfig::new().init();
        let rendered = alloc::format!("{layer}");
        assert_eq!(rendered, "Shrink {lambda: 0.5, bias: 0.5}");
    }
}