burn_core/module/reinit.rs

1use super::{Module, ModuleMapper};
2use burn_tensor::{
3    Element, ElementConversion, Tensor, TensorData,
4    backend::Backend,
5    ops::{FloatElem, IntElem},
6};
7use rand::{Rng, SeedableRng};
8
9#[derive(Debug)]
10/// Overrides float and int tensors of [burn modules](super::Module).
11///
12/// This is useful for testing.
13pub struct Reinitializer<B: Backend> {
14    float: ReinitStrategy<FloatElem<B>>,
15    int: ReinitStrategy<IntElem<B>>,
16}
17
/// How a [`Reinitializer`] regenerates the tensors of one element type `E`.
#[derive(Debug)]
enum ReinitStrategy<E> {
    /// Values that step linearly from `min` toward `max` across the tensor's elements.
    Range {
        /// Value given to the first element.
        min: E,
        /// End of the span the values step toward.
        max: E,
    },
    /// Every element is set to the same value.
    Constant {
        /// Value written to every element.
        value: E,
    },
    /// Values sampled uniformly (as `f64`) from `[min, max)` by a generator seeded with `seed`.
    Random {
        /// Seed of the random number generator.
        seed: u64,
        /// Inclusive lower bound of the sampled range.
        min: E,
        /// Exclusive upper bound of the sampled range.
        max: E,
    },
}
25
26impl<B: Backend> Default for Reinitializer<B> {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl<B: Backend> Reinitializer<B> {
33    /// Create a new [reinitializer](Reinitializer).
34    pub fn new() -> Self {
35        Self {
36            float: ReinitStrategy::Constant {
37                value: 0.elem::<FloatElem<B>>(),
38            },
39            int: ReinitStrategy::Constant {
40                value: 0.elem::<IntElem<B>>(),
41            },
42        }
43    }
44
45    /// Apply the reinitialization to the given [module](Module).
46    pub fn apply<M: Module<B>>(mut self, module: M) -> M {
47        module.map(&mut self)
48    }
49
50    /// Set the reinitialization strategy to constant for all tensors.
51    pub fn constant(self, constant: f64) -> Self {
52        self.constant_float(constant).constant_int(constant as i64)
53    }
54
55    /// Set the reinitialization strategy to constant for float tensors.
56    pub fn constant_float(mut self, constant: f64) -> Self {
57        self.float = ReinitStrategy::Constant {
58            value: constant.elem(),
59        };
60        self
61    }
62
63    /// Set the reinitialization strategy to constant for int tensors.
64    pub fn constant_int(mut self, constant: i64) -> Self {
65        self.int = ReinitStrategy::Constant {
66            value: constant.elem(),
67        };
68        self
69    }
70    /// Set the reinitialization strategy to random for all tensors.
71    pub fn random(self, seed: u64, min: f64, max: f64) -> Self {
72        self.random_float(seed, min, max)
73            .random_int(seed, min as i64, max as i64)
74    }
75
76    /// Set the reinitialization strategy to random for float tensors.
77    pub fn random_float(mut self, seed: u64, min: f64, max: f64) -> Self {
78        self.float = ReinitStrategy::Random {
79            seed,
80            min: min.elem(),
81            max: max.elem(),
82        };
83        self
84    }
85
86    /// Set the reinitialization strategy to random for int tensors.
87    pub fn random_int(mut self, seed: u64, min: i64, max: i64) -> Self {
88        self.int = ReinitStrategy::Random {
89            seed,
90            min: min.elem(),
91            max: max.elem(),
92        };
93        self
94    }
95
96    /// Set the reinitialization strategy to range for all tensors.
97    pub fn range(self, min: f64, max: f64) -> Self {
98        self.range_float(min, max).range_int(min as i64, max as i64)
99    }
100
101    /// Set the reinitialization strategy to range for float tensors.
102    pub fn range_float(mut self, min: f64, max: f64) -> Self {
103        self.float = ReinitStrategy::Range {
104            min: min.elem(),
105            max: max.elem(),
106        };
107        self
108    }
109
110    /// Set the reinitialization strategy to range for int tensors.
111    pub fn range_int(mut self, min: i64, max: i64) -> Self {
112        self.int = ReinitStrategy::Range {
113            min: min.elem(),
114            max: max.elem(),
115        };
116        self
117    }
118}
119
120impl<B: Backend> ModuleMapper<B> for Reinitializer<B> {
121    fn map_float<const D: usize>(
122        &mut self,
123        param: super::Param<Tensor<B, D>>,
124    ) -> super::Param<Tensor<B, D>> {
125        let (id, tensor, mapper) = param.consume();
126        let device = tensor.device();
127        let shape = tensor.shape();
128        let num_elements = shape.num_elements();
129
130        let tensor = match &self.float {
131            ReinitStrategy::Range { min, max } => {
132                let tensor = Tensor::arange(0..num_elements as i64, &device)
133                    .reshape(shape)
134                    .float();
135                let (factor, bias) = resolve::<FloatElem<B>>(*min, *max, num_elements);
136                tensor * factor + bias
137            }
138            ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
139            ReinitStrategy::Random { seed, min, max } => {
140                let data = TensorData::new(
141                    random_vector::<FloatElem<B>>(*seed, min.elem(), max.elem(), num_elements),
142                    shape,
143                );
144                Tensor::from_data(data, &device)
145            }
146        };
147
148        super::Param::from_mapped_value(id, tensor, mapper)
149    }
150
151    fn map_int<const D: usize>(
152        &mut self,
153        param: super::Param<Tensor<B, D, burn_tensor::Int>>,
154    ) -> super::Param<Tensor<B, D, burn_tensor::Int>> {
155        let (id, tensor, mapper) = param.consume();
156        let device = tensor.device();
157        let shape = tensor.shape();
158        let num_elements = shape.num_elements();
159
160        let tensor = match &self.int {
161            ReinitStrategy::Range { min, max } => {
162                let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape);
163                let (factor, bias) = resolve::<IntElem<B>>(*min, *max, num_elements);
164                tensor * factor + bias
165            }
166            ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
167            ReinitStrategy::Random { seed, min, max } => {
168                let data = TensorData::new(
169                    random_vector::<IntElem<B>>(*seed, min.elem(), max.elem(), num_elements),
170                    shape,
171                );
172                Tensor::from_data(data, &device)
173            }
174        };
175
176        super::Param::from_mapped_value(id, tensor, mapper)
177    }
178
179    fn map_bool<const D: usize>(
180        &mut self,
181        param: super::Param<Tensor<B, D, burn_tensor::Bool>>,
182    ) -> super::Param<Tensor<B, D, burn_tensor::Bool>> {
183        let (id, tensor, mapper) = param.consume();
184        super::Param::from_mapped_value(id, tensor, mapper)
185    }
186}
187
188fn resolve<E: Element>(min: E, max: E, num_elements: usize) -> (E, E) {
189    let range = max.elem::<f64>() - min.elem::<f64>();
190    let factor = range / num_elements as f64;
191    let bias = min.elem::<f64>();
192
193    (factor.elem(), bias.elem())
194}
195
196fn random_vector<E: Element>(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec<E> {
197    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
198    let dist = rand::distr::Uniform::new(min, max).unwrap();
199    (0..num_elements)
200        .map(|_| rng.sample(dist))
201        .map(|e| e.elem::<E>())
202        .collect()
203}