1use super::{Module, ModuleMapper};
2use burn_tensor::{
3 Element, ElementConversion, Tensor, TensorData,
4 backend::Backend,
5 ops::{FloatElem, IntElem},
6};
7use rand::{Rng, SeedableRng};
8
/// Module mapper that overwrites the values of a module's parameters.
///
/// Float and int parameters are replaced according to their own [`ReinitStrategy`]. Each
/// new tensor keeps the shape and device of the parameter it replaces. Bool parameters are
/// passed through untouched.
///
/// Start from [`Reinitializer::new`] (everything zero), adjust it with the builder methods
/// (`constant*`, `random*`, `range*`), then call [`Reinitializer::apply`].
#[derive(Debug)]
pub struct Reinitializer<B: Backend> {
    /// Strategy applied to every float parameter.
    float: ReinitStrategy<FloatElem<B>>,
    /// Strategy applied to every int parameter.
    int: ReinitStrategy<IntElem<B>>,
}
17
/// How one kind of parameter (float or int) is refilled by [`Reinitializer`].
#[derive(Debug)]
// The enum is private; the per-variant docs below describe the fields.
#[allow(missing_docs)]
enum ReinitStrategy<E> {
    /// Evenly spaced ramp over the flattened tensor.
    ///
    /// For float parameters, element `i` of an `n`-element tensor becomes
    /// `min + i * (max - min) / n`, so the ramp starts at `min` and never reaches `max`.
    /// Int parameters follow an integer approximation of the same ramp.
    Range { min: E, max: E },
    /// Every element is set to `value`.
    Constant { value: E },
    /// Values sampled uniformly between `min` and `max` (upper bound excluded).
    ///
    /// The RNG is recreated from `seed` for every parameter tensor, so tensors with the
    /// same number of elements receive the same sequence of values.
    Random { seed: u64, min: E, max: E },
}
25
26impl<B: Backend> Default for Reinitializer<B> {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl<B: Backend> Reinitializer<B> {
33 pub fn new() -> Self {
35 Self {
36 float: ReinitStrategy::Constant {
37 value: 0.elem::<FloatElem<B>>(),
38 },
39 int: ReinitStrategy::Constant {
40 value: 0.elem::<IntElem<B>>(),
41 },
42 }
43 }
44
45 pub fn apply<M: Module<B>>(mut self, module: M) -> M {
47 module.map(&mut self)
48 }
49
50 pub fn constant(self, constant: f64) -> Self {
52 self.constant_float(constant).constant_int(constant as i64)
53 }
54
55 pub fn constant_float(mut self, constant: f64) -> Self {
57 self.float = ReinitStrategy::Constant {
58 value: constant.elem(),
59 };
60 self
61 }
62
63 pub fn constant_int(mut self, constant: i64) -> Self {
65 self.int = ReinitStrategy::Constant {
66 value: constant.elem(),
67 };
68 self
69 }
70 pub fn random(self, seed: u64, min: f64, max: f64) -> Self {
72 self.random_float(seed, min, max)
73 .random_int(seed, min as i64, max as i64)
74 }
75
76 pub fn random_float(mut self, seed: u64, min: f64, max: f64) -> Self {
78 self.float = ReinitStrategy::Random {
79 seed,
80 min: min.elem(),
81 max: max.elem(),
82 };
83 self
84 }
85
86 pub fn random_int(mut self, seed: u64, min: i64, max: i64) -> Self {
88 self.int = ReinitStrategy::Random {
89 seed,
90 min: min.elem(),
91 max: max.elem(),
92 };
93 self
94 }
95
96 pub fn range(self, min: f64, max: f64) -> Self {
98 self.range_float(min, max).range_int(min as i64, max as i64)
99 }
100
101 pub fn range_float(mut self, min: f64, max: f64) -> Self {
103 self.float = ReinitStrategy::Range {
104 min: min.elem(),
105 max: max.elem(),
106 };
107 self
108 }
109
110 pub fn range_int(mut self, min: i64, max: i64) -> Self {
112 self.int = ReinitStrategy::Range {
113 min: min.elem(),
114 max: max.elem(),
115 };
116 self
117 }
118}
119
120impl<B: Backend> ModuleMapper<B> for Reinitializer<B> {
121 fn map_float<const D: usize>(
122 &mut self,
123 param: super::Param<Tensor<B, D>>,
124 ) -> super::Param<Tensor<B, D>> {
125 let (id, tensor, mapper) = param.consume();
126 let device = tensor.device();
127 let shape = tensor.shape();
128 let num_elements = shape.num_elements();
129
130 let tensor = match &self.float {
131 ReinitStrategy::Range { min, max } => {
132 let tensor = Tensor::arange(0..num_elements as i64, &device)
133 .reshape(shape)
134 .float();
135 let (factor, bias) = resolve::<FloatElem<B>>(*min, *max, num_elements);
136 tensor * factor + bias
137 }
138 ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
139 ReinitStrategy::Random { seed, min, max } => {
140 let data = TensorData::new(
141 random_vector::<FloatElem<B>>(*seed, min.elem(), max.elem(), num_elements),
142 shape,
143 );
144 Tensor::from_data(data, &device)
145 }
146 };
147
148 super::Param::from_mapped_value(id, tensor, mapper)
149 }
150
151 fn map_int<const D: usize>(
152 &mut self,
153 param: super::Param<Tensor<B, D, burn_tensor::Int>>,
154 ) -> super::Param<Tensor<B, D, burn_tensor::Int>> {
155 let (id, tensor, mapper) = param.consume();
156 let device = tensor.device();
157 let shape = tensor.shape();
158 let num_elements = shape.num_elements();
159
160 let tensor = match &self.int {
161 ReinitStrategy::Range { min, max } => {
162 let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape);
163 let (factor, bias) = resolve::<IntElem<B>>(*min, *max, num_elements);
164 tensor * factor + bias
165 }
166 ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
167 ReinitStrategy::Random { seed, min, max } => {
168 let data = TensorData::new(
169 random_vector::<IntElem<B>>(*seed, min.elem(), max.elem(), num_elements),
170 shape,
171 );
172 Tensor::from_data(data, &device)
173 }
174 };
175
176 super::Param::from_mapped_value(id, tensor, mapper)
177 }
178
179 fn map_bool<const D: usize>(
180 &mut self,
181 param: super::Param<Tensor<B, D, burn_tensor::Bool>>,
182 ) -> super::Param<Tensor<B, D, burn_tensor::Bool>> {
183 let (id, tensor, mapper) = param.consume();
184 super::Param::from_mapped_value(id, tensor, mapper)
185 }
186}
187
188fn resolve<E: Element>(min: E, max: E, num_elements: usize) -> (E, E) {
189 let range = max.elem::<f64>() - min.elem::<f64>();
190 let factor = range / num_elements as f64;
191 let bias = min.elem::<f64>();
192
193 (factor.elem(), bias.elem())
194}
195
196fn random_vector<E: Element>(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec<E> {
197 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
198 let dist = rand::distr::Uniform::new(min, max).unwrap();
199 (0..num_elements)
200 .map(|_| rng.sample(dist))
201 .map(|e| e.elem::<E>())
202 .collect()
203}