burn_core/module/
reinit.rs

1use super::{Module, ModuleMapper, ParamId};
2use burn_tensor::{
3    Element, ElementConversion, Tensor, TensorData,
4    backend::Backend,
5    ops::{FloatElem, IntElem},
6};
7use rand::{Rng, SeedableRng};
8
9#[derive(Debug)]
10/// Overrides float and int tensors of [burn modules](super::Module).
11///
12/// This is useful for testing.
13pub struct Reinitializer<B: Backend> {
14    float: ReinitStrategy<FloatElem<B>>,
15    int: ReinitStrategy<IntElem<B>>,
16}
17
/// How the values of a reinitialized tensor are produced.
///
/// `E` is the element type of the tensor kind the strategy applies to.
#[derive(Debug)]
enum ReinitStrategy<E> {
    /// Fill with a ramp of values that starts at `min` and steps toward `max`.
    Range {
        /// First value of the ramp.
        min: E,
        /// Bound the ramp steps toward.
        max: E,
    },
    /// Fill every position with the same `value`.
    Constant {
        /// Value written to every position.
        value: E,
    },
    /// Fill with pseudo-random values generated from `seed`.
    Random {
        /// Seed of the random number generator.
        seed: u64,
        /// Lower bound of the sampled values.
        min: E,
        /// Upper bound of the sampled values.
        max: E,
    },
}
25
26impl<B: Backend> Default for Reinitializer<B> {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl<B: Backend> Reinitializer<B> {
33    /// Create a new [reinitializer](Reinitializer).
34    pub fn new() -> Self {
35        Self {
36            float: ReinitStrategy::Constant {
37                value: 0.elem::<FloatElem<B>>(),
38            },
39            int: ReinitStrategy::Constant {
40                value: 0.elem::<IntElem<B>>(),
41            },
42        }
43    }
44
45    /// Apply the reinitialization to the given [module](Module).
46    pub fn apply<M: Module<B>>(mut self, module: M) -> M {
47        module.map(&mut self)
48    }
49
50    /// Set the reinitialization strategy to constant for all tensors.
51    pub fn constant(self, constant: f64) -> Self {
52        self.constant_float(constant).constant_int(constant as i64)
53    }
54
55    /// Set the reinitialization strategy to constant for float tensors.
56    pub fn constant_float(mut self, constant: f64) -> Self {
57        self.float = ReinitStrategy::Constant {
58            value: constant.elem(),
59        };
60        self
61    }
62
63    /// Set the reinitialization strategy to constant for int tensors.
64    pub fn constant_int(mut self, constant: i64) -> Self {
65        self.int = ReinitStrategy::Constant {
66            value: constant.elem(),
67        };
68        self
69    }
70    /// Set the reinitialization strategy to random for all tensors.
71    pub fn random(self, seed: u64, min: f64, max: f64) -> Self {
72        self.random_float(seed, min, max)
73            .random_int(seed, min as i64, max as i64)
74    }
75
76    /// Set the reinitialization strategy to random for float tensors.
77    pub fn random_float(mut self, seed: u64, min: f64, max: f64) -> Self {
78        self.float = ReinitStrategy::Random {
79            seed,
80            min: min.elem(),
81            max: max.elem(),
82        };
83        self
84    }
85
86    /// Set the reinitialization strategy to random for int tensors.
87    pub fn random_int(mut self, seed: u64, min: i64, max: i64) -> Self {
88        self.int = ReinitStrategy::Random {
89            seed,
90            min: min.elem(),
91            max: max.elem(),
92        };
93        self
94    }
95
96    /// Set the reinitialization strategy to range for all tensors.
97    pub fn range(self, min: f64, max: f64) -> Self {
98        self.range_float(min, max).range_int(min as i64, max as i64)
99    }
100
101    /// Set the reinitialization strategy to range for float tensors.
102    pub fn range_float(mut self, min: f64, max: f64) -> Self {
103        self.float = ReinitStrategy::Range {
104            min: min.elem(),
105            max: max.elem(),
106        };
107        self
108    }
109
110    /// Set the reinitialization strategy to range for int tensors.
111    pub fn range_int(mut self, min: i64, max: i64) -> Self {
112        self.int = ReinitStrategy::Range {
113            min: min.elem(),
114            max: max.elem(),
115        };
116        self
117    }
118}
119
120impl<B: Backend> ModuleMapper<B> for Reinitializer<B> {
121    fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
122        let device = tensor.device();
123        let shape = tensor.shape();
124        let num_elements = shape.num_elements();
125
126        match &self.float {
127            ReinitStrategy::Range { min, max } => {
128                let tensor = Tensor::arange(0..num_elements as i64, &device)
129                    .reshape(shape)
130                    .float();
131                let (factor, bias) = resolve::<FloatElem<B>>(*min, *max, num_elements);
132                tensor * factor + bias
133            }
134            ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
135            ReinitStrategy::Random { seed, min, max } => {
136                let data = TensorData::new(
137                    random_vector::<FloatElem<B>>(*seed, min.elem(), max.elem(), num_elements),
138                    shape,
139                );
140                Tensor::from_data(data, &device)
141            }
142        }
143    }
144
145    fn map_int<const D: usize>(
146        &mut self,
147        _id: ParamId,
148        tensor: Tensor<B, D, burn_tensor::Int>,
149    ) -> Tensor<B, D, burn_tensor::Int> {
150        let device = tensor.device();
151        let shape = tensor.shape();
152        let num_elements = shape.num_elements();
153
154        match &self.int {
155            ReinitStrategy::Range { min, max } => {
156                let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape);
157                let (factor, bias) = resolve::<IntElem<B>>(*min, *max, num_elements);
158                tensor * factor + bias
159            }
160            ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),
161            ReinitStrategy::Random { seed, min, max } => {
162                let data = TensorData::new(
163                    random_vector::<IntElem<B>>(*seed, min.elem(), max.elem(), num_elements),
164                    shape,
165                );
166                Tensor::from_data(data, &device)
167            }
168        }
169    }
170
171    fn map_bool<const D: usize>(
172        &mut self,
173        _id: ParamId,
174        tensor: Tensor<B, D, burn_tensor::Bool>,
175    ) -> Tensor<B, D, burn_tensor::Bool> {
176        tensor
177    }
178}
179
180fn resolve<E: Element>(min: E, max: E, num_elements: usize) -> (E, E) {
181    let range = max.elem::<f64>() - min.elem::<f64>();
182    let factor = range / num_elements as f64;
183    let bias = min.elem::<f64>();
184
185    (factor.elem(), bias.elem())
186}
187
188fn random_vector<E: Element>(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec<E> {
189    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
190    let dist = rand::distr::Uniform::new(min, max).unwrap();
191    (0..num_elements)
192        .map(|_| rng.sample(dist))
193        .map(|e| e.elem::<E>())
194        .collect()
195}