Skip to main content

oxigdal_ml/optimization/distillation/
network.rs

1//! Neural network components for knowledge distillation
2
/// Minimal xorshift64 pseudo-random generator (shift triple 13 / 7 / 17).
///
/// The generator is fully deterministic for a given seed, which makes weight
/// initialisation and shuffling reproducible. It is not cryptographically secure.
#[derive(Debug, Clone)]
pub struct SimpleRng {
    /// Current generator state. It is never zero: zero is a fixed point of the
    /// xorshift step, so the constructor replaces a zero seed.
    state: u64,
}

impl SimpleRng {
    /// Creates a new RNG from `seed`.
    ///
    /// A seed of `0` is replaced by `1`, because an all-zero state would make the
    /// generator return zero forever.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Advances the generator one xorshift step and returns the new state.
    pub fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Generates a random `f32` in `[0, 1)`.
    ///
    /// The raw 64-bit value is scaled through `f64` and then narrowed to `f32`.
    /// Values close to `u64::MAX` round up to exactly `1.0` during that
    /// conversion, so the result is clamped to the largest `f32` below one. This
    /// keeps the upper bound exclusive, and every other output is unchanged.
    pub fn next_f32(&mut self) -> f32 {
        // 1 - 2^-24: the spacing of f32 values in [0.5, 1) is 2^-24.
        const LARGEST_BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;
        let unit = (self.next_u64() as f64 / u64::MAX as f64) as f32;
        unit.min(LARGEST_BELOW_ONE)
    }

    /// Generates a normally distributed `f32` (mean 0, standard deviation 1)
    /// using the Box–Muller transform.
    ///
    /// Only the cosine branch of the transform is used, so each call consumes
    /// two uniform draws.
    pub fn next_normal(&mut self) -> f32 {
        // Keep u1 away from zero so ln(u1) stays finite.
        let u1 = self.next_f32().max(1e-10);
        let u2 = self.next_f32();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
    }

    /// Shuffles `slice` in place with a Fisher–Yates pass from the back.
    ///
    /// Indices are chosen with `% (i + 1)`, which carries a negligible modulo
    /// bias for realistic slice lengths.
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        for i in (1..slice.len()).rev() {
            let j = (self.next_u64() as usize) % (i + 1);
            slice.swap(i, j);
        }
    }
}
44
45/// A simple dense layer for demonstration purposes
46#[derive(Debug, Clone)]
47pub struct DenseLayer {
48    /// Weight matrix (flattened: input_size * output_size)
49    pub weights: Vec<f32>,
50    /// Bias vector
51    pub bias: Vec<f32>,
52    /// Input size
53    pub input_size: usize,
54    /// Output size
55    pub output_size: usize,
56}
57
58impl DenseLayer {
59    /// Creates a new dense layer with Xavier initialization
60    #[must_use]
61    pub fn new(input_size: usize, output_size: usize, seed: u64) -> Self {
62        let scale = (2.0 / (input_size + output_size) as f32).sqrt();
63        let mut rng = SimpleRng::new(seed);
64
65        let weights: Vec<f32> = (0..input_size * output_size)
66            .map(|_| rng.next_normal() * scale)
67            .collect();
68
69        let bias = vec![0.0; output_size];
70
71        Self {
72            weights,
73            bias,
74            input_size,
75            output_size,
76        }
77    }
78
79    /// Forward pass
80    #[must_use]
81    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
82        let mut output = self.bias.clone();
83
84        for (o_idx, out) in output.iter_mut().enumerate() {
85            for (i_idx, &inp) in input.iter().enumerate() {
86                let w_idx = o_idx * self.input_size + i_idx;
87                if let Some(&w) = self.weights.get(w_idx) {
88                    *out += inp * w;
89                }
90            }
91        }
92
93        output
94    }
95
96    /// Backward pass computing gradients w.r.t. weights, bias, and input
97    #[must_use]
98    pub fn backward(&self, input: &[f32], grad_output: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
99        // Gradient w.r.t. weights
100        let mut grad_weights = vec![0.0; self.weights.len()];
101        for (o_idx, &go) in grad_output.iter().enumerate() {
102            for (i_idx, &inp) in input.iter().enumerate() {
103                let w_idx = o_idx * self.input_size + i_idx;
104                if w_idx < grad_weights.len() {
105                    grad_weights[w_idx] += go * inp;
106                }
107            }
108        }
109
110        // Gradient w.r.t. bias
111        let grad_bias = grad_output.to_vec();
112
113        // Gradient w.r.t. input
114        let mut grad_input = vec![0.0; self.input_size];
115        for (o_idx, &go) in grad_output.iter().enumerate() {
116            for (i_idx, gi) in grad_input.iter_mut().enumerate() {
117                let w_idx = o_idx * self.input_size + i_idx;
118                if let Some(&w) = self.weights.get(w_idx) {
119                    *gi += go * w;
120                }
121            }
122        }
123
124        (grad_weights, grad_bias, grad_input)
125    }
126
127    /// Returns the total number of parameters
128    #[must_use]
129    pub fn num_params(&self) -> usize {
130        self.weights.len() + self.bias.len()
131    }
132
133    /// Gets all parameters as a flat vector
134    #[must_use]
135    pub fn get_params(&self) -> Vec<f32> {
136        let mut params = self.weights.clone();
137        params.extend(&self.bias);
138        params
139    }
140
141    /// Sets parameters from a flat vector
142    pub fn set_params(&mut self, params: &[f32]) {
143        let w_end = self.weights.len();
144        let b_len = self.bias.len();
145        if params.len() >= w_end + b_len {
146            self.weights.copy_from_slice(&params[..w_end]);
147            self.bias.copy_from_slice(&params[w_end..w_end + b_len]);
148        }
149    }
150}
151
/// Intermediate values recorded during a forward pass and consumed by the
/// matching backward pass.
#[derive(Clone, Debug)]
pub struct ForwardCache {
    /// The network input, copied as given
    pub input: Vec<f32>,
    /// Hidden-layer output before the ReLU is applied
    pub hidden_pre: Vec<f32>,
    /// Hidden-layer output after the ReLU is applied
    pub hidden_post: Vec<f32>,
}
162
/// Gradients of every trainable parameter of a two-layer MLP, grouped by layer.
#[derive(Debug, Clone)]
pub struct MLPGradients {
    /// Gradient of the loss w.r.t. the hidden layer's weights
    pub hidden_weights: Vec<f32>,
    /// Gradient of the loss w.r.t. the hidden layer's biases
    pub hidden_bias: Vec<f32>,
    /// Gradient of the loss w.r.t. the output layer's weights
    pub output_weights: Vec<f32>,
    /// Gradient of the loss w.r.t. the output layer's biases
    pub output_bias: Vec<f32>,
}

impl MLPGradients {
    /// Concatenates all gradients into one vector.
    ///
    /// The order is hidden weights, hidden biases, output weights, output biases.
    #[must_use]
    pub fn flatten(&self) -> Vec<f32> {
        [
            self.hidden_weights.as_slice(),
            self.hidden_bias.as_slice(),
            self.output_weights.as_slice(),
            self.output_bias.as_slice(),
        ]
        .concat()
    }
}
187
188/// A simple two-layer MLP for student model
189#[derive(Debug, Clone)]
190pub struct SimpleMLP {
191    /// Hidden layer
192    pub hidden: DenseLayer,
193    /// Output layer
194    pub output: DenseLayer,
195}
196
197impl SimpleMLP {
198    /// Creates a new simple MLP
199    #[must_use]
200    pub fn new(input_size: usize, hidden_size: usize, output_size: usize, seed: u64) -> Self {
201        Self {
202            hidden: DenseLayer::new(input_size, hidden_size, seed),
203            output: DenseLayer::new(hidden_size, output_size, seed.wrapping_add(1)),
204        }
205    }
206
207    /// Forward pass returning logits
208    #[must_use]
209    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
210        let hidden_out = self.hidden.forward(input);
211        // ReLU activation
212        let hidden_activated: Vec<f32> = hidden_out.iter().map(|&x| x.max(0.0)).collect();
213        self.output.forward(&hidden_activated)
214    }
215
216    /// Forward pass with cached activations for backprop
217    #[must_use]
218    pub fn forward_with_cache(&self, input: &[f32]) -> (Vec<f32>, ForwardCache) {
219        let hidden_pre = self.hidden.forward(input);
220        let hidden_post: Vec<f32> = hidden_pre.iter().map(|&x| x.max(0.0)).collect();
221        let output = self.output.forward(&hidden_post);
222
223        let cache = ForwardCache {
224            input: input.to_vec(),
225            hidden_pre,
226            hidden_post,
227        };
228
229        (output, cache)
230    }
231
232    /// Backward pass computing all gradients
233    pub fn backward(&self, grad_output: &[f32], cache: &ForwardCache) -> MLPGradients {
234        // Backward through output layer
235        let (grad_out_weights, grad_out_bias, grad_hidden) =
236            self.output.backward(&cache.hidden_post, grad_output);
237
238        // Backward through ReLU
239        let grad_hidden_pre: Vec<f32> = grad_hidden
240            .iter()
241            .zip(cache.hidden_pre.iter())
242            .map(|(&g, &h)| if h > 0.0 { g } else { 0.0 })
243            .collect();
244
245        // Backward through hidden layer
246        let (grad_hidden_weights, grad_hidden_bias, _) =
247            self.hidden.backward(&cache.input, &grad_hidden_pre);
248
249        MLPGradients {
250            hidden_weights: grad_hidden_weights,
251            hidden_bias: grad_hidden_bias,
252            output_weights: grad_out_weights,
253            output_bias: grad_out_bias,
254        }
255    }
256
257    /// Total number of parameters
258    #[must_use]
259    pub fn num_params(&self) -> usize {
260        self.hidden.num_params() + self.output.num_params()
261    }
262
263    /// Get all parameters as flat vector
264    #[must_use]
265    pub fn get_params(&self) -> Vec<f32> {
266        let mut params = self.hidden.get_params();
267        params.extend(self.output.get_params());
268        params
269    }
270
271    /// Set parameters from flat vector
272    pub fn set_params(&mut self, params: &[f32]) {
273        let hidden_size = self.hidden.num_params();
274        self.hidden.set_params(&params[..hidden_size]);
275        self.output.set_params(&params[hidden_size..]);
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_rng() {
        // Identical seeds must produce identical streams.
        let first = SimpleRng::new(42).next_u64();
        let again = SimpleRng::new(42).next_u64();
        assert_eq!(first, again);

        let mut rng = SimpleRng::new(123);
        assert!((0..100)
            .map(|_| rng.next_f32())
            .all(|f| (0.0..1.0).contains(&f)));
    }

    #[test]
    fn test_dense_layer_forward() {
        let output = DenseLayer::new(4, 3, 42).forward(&[1.0, 2.0, 3.0, 4.0]);

        assert_eq!(output.len(), 3);
        assert!(output.iter().all(|o| o.is_finite()));
    }

    #[test]
    fn test_dense_layer_backward() {
        let layer = DenseLayer::new(4, 3, 42);
        let (grad_w, grad_b, grad_i) = layer.backward(&[1.0, 2.0, 3.0, 4.0], &[0.1, 0.2, 0.3]);

        assert_eq!(
            (grad_w.len(), grad_b.len(), grad_i.len()),
            (4 * 3, 3, 4)
        );
    }

    #[test]
    fn test_simple_mlp_forward() {
        let logits = SimpleMLP::new(10, 20, 5, 42).forward(&[0.1; 10]);

        assert_eq!(logits.len(), 5);
        assert!(logits.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_simple_mlp_params() {
        let mlp = SimpleMLP::new(10, 20, 5, 42);

        // hidden: 10*20 weights + 20 biases = 220; output: 20*5 weights + 5 biases = 105
        let expected = (10 * 20 + 20) + (20 * 5 + 5);
        assert_eq!(expected, 325);
        assert_eq!(mlp.get_params().len(), expected);
        assert_eq!(mlp.num_params(), expected);
    }
}