// oxigdal_ml/optimization/distillation/network.rs

/// Minimal xorshift64 PRNG used for deterministic weight initialization,
/// Gaussian sampling, and shuffling; not cryptographically secure.
#[derive(Debug, Clone)]
pub struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    #[must_use]
    pub fn new(seed: u64) -> Self {
        // Xorshift state must never be all zeros, so force a nonzero seed.
        Self { state: seed.max(1) }
    }

    /// One step of xorshift64 (shift constants 13, 7, 17).
    pub fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Uniform sample in [0.0, 1.0). Uses the top 24 bits of the state so the
    /// quotient is exact in f32 and can never round up to 1.0.
    pub fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Standard normal sample via the Box-Muller transform.
    pub fn next_normal(&mut self) -> f32 {
        // Clamp u1 away from zero so that ln(u1) stays finite.
        let u1 = self.next_f32().max(1e-10);
        let u2 = self.next_f32();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
    }

    /// In-place Fisher-Yates shuffle. The modulo introduces a negligible bias
    /// for slice lengths far below u64::MAX.
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        for i in (1..slice.len()).rev() {
            let j = (self.next_u64() as usize) % (i + 1);
            slice.swap(i, j);
        }
    }
}

/// A fully connected layer with weights stored in row-major order
/// (one contiguous row of `input_size` weights per output unit).
#[derive(Debug, Clone)]
pub struct DenseLayer {
    /// Row-major weight matrix with `output_size * input_size` entries.
    pub weights: Vec<f32>,
    /// One bias per output unit.
    pub bias: Vec<f32>,
    /// Number of inputs the layer expects.
    pub input_size: usize,
    /// Number of outputs the layer produces.
    pub output_size: usize,
}

impl DenseLayer {
    #[must_use]
    pub fn new(input_size: usize, output_size: usize, seed: u64) -> Self {
        // Xavier/Glorot initialization: scale = sqrt(2 / (fan_in + fan_out)).
        let scale = (2.0 / (input_size + output_size) as f32).sqrt();
        let mut rng = SimpleRng::new(seed);

        let weights: Vec<f32> = (0..input_size * output_size)
            .map(|_| rng.next_normal() * scale)
            .collect();

        let bias = vec![0.0; output_size];

        Self {
            weights,
            bias,
            input_size,
            output_size,
        }
    }

    /// Computes `W * input + bias` for a single input vector.
    #[must_use]
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut output = self.bias.clone();

        for (o_idx, out) in output.iter_mut().enumerate() {
            for (i_idx, &inp) in input.iter().enumerate() {
                let w_idx = o_idx * self.input_size + i_idx;
                if let Some(&w) = self.weights.get(w_idx) {
                    *out += inp * w;
                }
            }
        }

        output
    }

    /// Backpropagates through the layer, returning
    /// `(grad_weights, grad_bias, grad_input)`.
    #[must_use]
    pub fn backward(&self, input: &[f32], grad_output: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        // dL/dW[o][i] = grad_output[o] * input[i]
        let mut grad_weights = vec![0.0; self.weights.len()];
        for (o_idx, &go) in grad_output.iter().enumerate() {
            for (i_idx, &inp) in input.iter().enumerate() {
                let w_idx = o_idx * self.input_size + i_idx;
                if w_idx < grad_weights.len() {
                    grad_weights[w_idx] += go * inp;
                }
            }
        }

        // dL/db[o] = grad_output[o]
        let grad_bias = grad_output.to_vec();

        // dL/dinput[i] = sum over o of grad_output[o] * W[o][i]
        let mut grad_input = vec![0.0; self.input_size];
        for (o_idx, &go) in grad_output.iter().enumerate() {
            for (i_idx, gi) in grad_input.iter_mut().enumerate() {
                let w_idx = o_idx * self.input_size + i_idx;
                if let Some(&w) = self.weights.get(w_idx) {
                    *gi += go * w;
                }
            }
        }

        (grad_weights, grad_bias, grad_input)
    }

    #[must_use]
    pub fn num_params(&self) -> usize {
        self.weights.len() + self.bias.len()
    }

    /// Flattens parameters as weights followed by biases.
    #[must_use]
    pub fn get_params(&self) -> Vec<f32> {
        let mut params = self.weights.clone();
        params.extend(&self.bias);
        params
    }

    /// Restores parameters from a flat slice laid out as weights then biases;
    /// a slice that is too short is ignored.
    pub fn set_params(&mut self, params: &[f32]) {
        let w_end = self.weights.len();
        let b_len = self.bias.len();
        if params.len() >= w_end + b_len {
            self.weights.copy_from_slice(&params[..w_end]);
            self.bias.copy_from_slice(&params[w_end..w_end + b_len]);
        }
    }
}

/// Activations saved by `SimpleMLP::forward_with_cache` for reuse in
/// `SimpleMLP::backward`.
#[derive(Debug, Clone)]
pub struct ForwardCache {
    /// The network input.
    pub input: Vec<f32>,
    /// Hidden-layer values before the ReLU.
    pub hidden_pre: Vec<f32>,
    /// Hidden-layer values after the ReLU.
    pub hidden_post: Vec<f32>,
}

/// Per-layer gradients produced by `SimpleMLP::backward`.
#[derive(Debug, Clone)]
pub struct MLPGradients {
    pub hidden_weights: Vec<f32>,
    pub hidden_bias: Vec<f32>,
    pub output_weights: Vec<f32>,
    pub output_bias: Vec<f32>,
}

impl MLPGradients {
    /// Flattens all gradients in the same order as `SimpleMLP::get_params`.
    #[must_use]
    pub fn flatten(&self) -> Vec<f32> {
        let mut flat = self.hidden_weights.clone();
        flat.extend(&self.hidden_bias);
        flat.extend(&self.output_weights);
        flat.extend(&self.output_bias);
        flat
    }
}

/// A two-layer perceptron: dense -> ReLU -> dense.
#[derive(Debug, Clone)]
pub struct SimpleMLP {
    pub hidden: DenseLayer,
    pub output: DenseLayer,
}

impl SimpleMLP {
    #[must_use]
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize, seed: u64) -> Self {
        Self {
            hidden: DenseLayer::new(input_size, hidden_size, seed),
            // Offset the seed so the two layers start with different weights.
            output: DenseLayer::new(hidden_size, output_size, seed.wrapping_add(1)),
        }
    }

    /// Runs hidden -> ReLU -> output.
    #[must_use]
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let hidden_out = self.hidden.forward(input);
        let hidden_activated: Vec<f32> = hidden_out.iter().map(|&x| x.max(0.0)).collect();
        self.output.forward(&hidden_activated)
    }

    /// Like `forward`, but also returns the activations that `backward` needs.
    #[must_use]
    pub fn forward_with_cache(&self, input: &[f32]) -> (Vec<f32>, ForwardCache) {
        let hidden_pre = self.hidden.forward(input);
        let hidden_post: Vec<f32> = hidden_pre.iter().map(|&x| x.max(0.0)).collect();
        let output = self.output.forward(&hidden_post);

        let cache = ForwardCache {
            input: input.to_vec(),
            hidden_pre,
            hidden_post,
        };

        (output, cache)
    }

    /// Backpropagates `grad_output` through both layers using the cached
    /// activations from `forward_with_cache`.
    #[must_use]
    pub fn backward(&self, grad_output: &[f32], cache: &ForwardCache) -> MLPGradients {
        let (grad_out_weights, grad_out_bias, grad_hidden) =
            self.output.backward(&cache.hidden_post, grad_output);

        // ReLU gradient: pass through only where the pre-activation was positive.
        let grad_hidden_pre: Vec<f32> = grad_hidden
            .iter()
            .zip(cache.hidden_pre.iter())
            .map(|(&g, &h)| if h > 0.0 { g } else { 0.0 })
            .collect();

        let (grad_hidden_weights, grad_hidden_bias, _) =
            self.hidden.backward(&cache.input, &grad_hidden_pre);

        MLPGradients {
            hidden_weights: grad_hidden_weights,
            hidden_bias: grad_hidden_bias,
            output_weights: grad_out_weights,
            output_bias: grad_out_bias,
        }
    }

    #[must_use]
    pub fn num_params(&self) -> usize {
        self.hidden.num_params() + self.output.num_params()
    }

    /// Flattens parameters as the hidden layer first, then the output layer.
    #[must_use]
    pub fn get_params(&self) -> Vec<f32> {
        let mut params = self.hidden.get_params();
        params.extend(self.output.get_params());
        params
    }

    pub fn set_params(&mut self, params: &[f32]) {
        let hidden_size = self.hidden.num_params();
        self.hidden.set_params(&params[..hidden_size]);
        self.output.set_params(&params[hidden_size..]);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_rng() {
        let mut rng = SimpleRng::new(42);

        let val1 = rng.next_u64();

        let mut rng2 = SimpleRng::new(42);
        let val2 = rng2.next_u64();

        assert_eq!(val1, val2);

        let mut rng3 = SimpleRng::new(123);
        for _ in 0..100 {
            let f = rng3.next_f32();
            assert!((0.0..1.0).contains(&f));
        }
    }
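
    // Added coverage: `shuffle` should permute in place without losing or
    // duplicating elements, and `next_normal` should stay finite over many draws.
    #[test]
    fn test_shuffle_and_normal() {
        let mut rng = SimpleRng::new(7);
        let mut data: Vec<u32> = (0..50).collect();
        rng.shuffle(&mut data);

        let mut sorted = data.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..50).collect::<Vec<u32>>());

        for _ in 0..1000 {
            assert!(rng.next_normal().is_finite());
        }
    }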

    #[test]
    fn test_dense_layer_forward() {
        let layer = DenseLayer::new(4, 3, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = layer.forward(&input);

        assert_eq!(output.len(), 3);
        for &o in &output {
            assert!(o.is_finite());
        }
    }

    #[test]
    fn test_dense_layer_backward() {
        let layer = DenseLayer::new(4, 3, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let grad_output = vec![0.1, 0.2, 0.3];

        let (grad_w, grad_b, grad_i) = layer.backward(&input, &grad_output);

        assert_eq!(grad_w.len(), 4 * 3);
        assert_eq!(grad_b.len(), 3);
        assert_eq!(grad_i.len(), 4);
    }
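
    // Added coverage: a finite-difference check that `backward`'s analytic
    // weight gradients match numerical gradients of the scalar loss
    // L = sum(forward(input)); the seed and layer sizes here are arbitrary.
    #[test]
    fn test_dense_layer_gradient_check() {
        let mut layer = DenseLayer::new(3, 2, 7);
        let input = vec![0.5, -1.0, 2.0];
        // dL/d(output[o]) = 1.0 for every output when L is a plain sum.
        let grad_output = vec![1.0, 1.0];
        let (grad_w, _, _) = layer.backward(&input, &grad_output);

        let eps = 1e-3;
        for idx in 0..layer.weights.len() {
            let orig = layer.weights[idx];
            layer.weights[idx] = orig + eps;
            let plus: f32 = layer.forward(&input).iter().sum();
            layer.weights[idx] = orig - eps;
            let minus: f32 = layer.forward(&input).iter().sum();
            layer.weights[idx] = orig;

            let numeric = (plus - minus) / (2.0 * eps);
            assert!((grad_w[idx] - numeric).abs() < 1e-2);
        }
    }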

    #[test]
    fn test_simple_mlp_forward() {
        let mlp = SimpleMLP::new(10, 20, 5, 42);
        let input = vec![0.1; 10];
        let output = mlp.forward(&input);

        assert_eq!(output.len(), 5);
        for &o in &output {
            assert!(o.is_finite());
        }
    }
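
    // Added coverage: `forward` and `forward_with_cache` run the same
    // hidden -> ReLU -> output pipeline, so their outputs must match exactly.
    #[test]
    fn test_forward_with_cache_matches_forward() {
        let mlp = SimpleMLP::new(6, 8, 3, 99);
        let input = vec![0.25; 6];

        let plain = mlp.forward(&input);
        let (cached, cache) = mlp.forward_with_cache(&input);

        assert_eq!(plain, cached);
        assert_eq!(cache.input, input);
        assert_eq!(cache.hidden_pre.len(), 8);
        assert_eq!(cache.hidden_post.len(), 8);
    }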

    #[test]
    fn test_simple_mlp_params() {
        let mlp = SimpleMLP::new(10, 20, 5, 42);
        let params = mlp.get_params();

        // (10 * 20 + 20) hidden + (20 * 5 + 5) output = 220 + 105 = 325.
        assert_eq!(params.len(), 325);
        assert_eq!(mlp.num_params(), 325);
    }
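
    // Added coverage: round-trip the flat parameter vector, then take one small
    // gradient-descent step on a squared-error loss. With a step this small the
    // loss should drop; the target and learning rate are arbitrary choices.
    #[test]
    fn test_param_roundtrip_and_training_step() {
        let mut mlp = SimpleMLP::new(4, 6, 2, 42);
        let params = mlp.get_params();
        mlp.set_params(&params);
        assert_eq!(mlp.get_params(), params);

        let input = vec![0.5, -0.5, 1.0, 0.0];
        let target = [1.0_f32, -1.0];
        let loss = |out: &[f32]| -> f32 {
            out.iter().zip(&target).map(|(o, t)| (o - t) * (o - t)).sum()
        };

        let (out, cache) = mlp.forward_with_cache(&input);
        let before = loss(&out);

        // dL/d(out[i]) = 2 * (out[i] - target[i]) for squared error.
        let grad_out: Vec<f32> = out.iter().zip(&target).map(|(o, t)| 2.0 * (o - t)).collect();
        let grads = mlp.backward(&grad_out, &cache).flatten();

        // One plain gradient-descent step on the flattened parameters.
        let lr = 0.01;
        let updated: Vec<f32> = params.iter().zip(&grads).map(|(p, g)| p - lr * g).collect();
        mlp.set_params(&updated);

        let after = loss(&mlp.forward(&input));
        assert!(after < before);
    }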
}