1use super::{Module, ModuleMapper, ParamId};
2use burn_tensor::{
3 Element, ElementConversion, Tensor, TensorData,
4 backend::Backend,
5 ops::{FloatElem, IntElem},
6};
7use rand::{Rng, SeedableRng};
8
9#[derive(Debug)]
10pub struct Reinitializer<B: Backend> {
14 float: ReinitStrategy<FloatElem<B>>,
15 int: ReinitStrategy<IntElem<B>>,
16}
17
/// How replacement values are generated for tensors whose element type is `E`.
#[derive(Debug)]
// Private enum: the variant fields are self-describing, so per-field docs are skipped.
#[allow(missing_docs)]
enum ReinitStrategy<E> {
    /// Every element is set to `value`.
    Constant {
        value: E,
    },
    /// Elements follow a ramp that starts at `min` and grows by
    /// `(max - min) / num_elements` per element.
    Range {
        min: E,
        max: E,
    },
    /// Elements are drawn uniformly between `min` and `max` from a generator
    /// seeded with `seed`.
    Random {
        seed: u64,
        min: E,
        max: E,
    },
}
25
26impl<B: Backend> Default for Reinitializer<B> {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl<B: Backend> Reinitializer<B> {
33 pub fn new() -> Self {
35 Self {
36 float: ReinitStrategy::Constant {
37 value: 0.elem::<FloatElem<B>>(),
38 },
39 int: ReinitStrategy::Constant {
40 value: 0.elem::<IntElem<B>>(),
41 },
42 }
43 }
44
45 pub fn apply<M: Module<B>>(mut self, module: M) -> M {
47 module.map(&mut self)
48 }
49
50 pub fn constant(self, constant: f64) -> Self {
52 self.constant_float(constant).constant_int(constant as i64)
53 }
54
55 pub fn constant_float(mut self, constant: f64) -> Self {
57 self.float = ReinitStrategy::Constant {
58 value: constant.elem(),
59 };
60 self
61 }
62
63 pub fn constant_int(mut self, constant: i64) -> Self {
65 self.int = ReinitStrategy::Constant {
66 value: constant.elem(),
67 };
68 self
69 }
70 pub fn random(self, seed: u64, min: f64, max: f64) -> Self {
72 self.random_float(seed, min, max)
73 .random_int(seed, min as i64, max as i64)
74 }
75
76 pub fn random_float(mut self, seed: u64, min: f64, max: f64) -> Self {
78 self.float = ReinitStrategy::Random {
79 seed,
80 min: min.elem(),
81 max: max.elem(),
82 };
83 self
84 }
85
86 pub fn random_int(mut self, seed: u64, min: i64, max: i64) -> Self {
88 self.int = ReinitStrategy::Random {
89 seed,
90 min: min.elem(),
91 max: max.elem(),
92 };
93 self
94 }
95
96 pub fn range(self, min: f64, max: f64) -> Self {
98 self.range_float(min, max).range_int(min as i64, max as i64)
99 }
100
101 pub fn range_float(mut self, min: f64, max: f64) -> Self {
103 self.float = ReinitStrategy::Range {
104 min: min.elem(),
105 max: max.elem(),
106 };
107 self
108 }
109
110 pub fn range_int(mut self, min: i64, max: i64) -> Self {
112 self.int = ReinitStrategy::Range {
113 min: min.elem(),
114 max: max.elem(),
115 };
116 self
117 }
118}
119
120impl<B: Backend> ModuleMapper<B> for Reinitializer<B> {
121 fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
122 let device = tensor.device();
123 let shape = tensor.shape();
124 let num_elements = shape.num_elements();
125
126 match &self.float {
127 ReinitStrategy::Range { min, max } => {
128 let tensor = Tensor::arange(0..num_elements as i64, &device)
129 .reshape(shape)
130 .float();
131 let (factor, bias) = resolve::<FloatElem<B>>(*min, *max, num_elements);
132 tensor * factor + bias
133 }
134 ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
135 ReinitStrategy::Random { seed, min, max } => {
136 let data = TensorData::new(
137 random_vector::<FloatElem<B>>(*seed, min.elem(), max.elem(), num_elements),
138 shape,
139 );
140 Tensor::from_data(data, &device)
141 }
142 }
143 }
144
145 fn map_int<const D: usize>(
146 &mut self,
147 _id: ParamId,
148 tensor: Tensor<B, D, burn_tensor::Int>,
149 ) -> Tensor<B, D, burn_tensor::Int> {
150 let device = tensor.device();
151 let shape = tensor.shape();
152 let num_elements = shape.num_elements();
153
154 match &self.int {
155 ReinitStrategy::Range { min, max } => {
156 let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape);
157 let (factor, bias) = resolve::<IntElem<B>>(*min, *max, num_elements);
158 tensor * factor + bias
159 }
160 ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
161 ReinitStrategy::Random { seed, min, max } => {
162 let data = TensorData::new(
163 random_vector::<IntElem<B>>(*seed, min.elem(), max.elem(), num_elements),
164 shape,
165 );
166 Tensor::from_data(data, &device)
167 }
168 }
169 }
170
171 fn map_bool<const D: usize>(
172 &mut self,
173 _id: ParamId,
174 tensor: Tensor<B, D, burn_tensor::Bool>,
175 ) -> Tensor<B, D, burn_tensor::Bool> {
176 tensor
177 }
178}
179
180fn resolve<E: Element>(min: E, max: E, num_elements: usize) -> (E, E) {
181 let range = max.elem::<f64>() - min.elem::<f64>();
182 let factor = range / num_elements as f64;
183 let bias = min.elem::<f64>();
184
185 (factor.elem(), bias.elem())
186}
187
188fn random_vector<E: Element>(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec<E> {
189 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
190 let dist = rand::distr::Uniform::new(min, max).unwrap();
191 (0..num_elements)
192 .map(|_| rng.sample(dist))
193 .map(|e| e.elem::<E>())
194 .collect()
195}