// enn/nn.rs

//! Neural Network Training Infrastructure
//!
//! Provides trainable neural networks with gradient computation,
//! parameter storage, and optimization capabilities.

use std::collections::HashMap;
8/// Trainable parameter that stores value and gradient
9#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
10pub struct Parameter {
11    pub value: f32,
12    #[serde(skip)] // Don't serialize gradients - they're transient
13    pub gradient: f32,
14}
15
16impl Parameter {
17    pub fn new(value: f32) -> Self {
18        Self {
19            value,
20            gradient: 0.0,
21        }
22    }
23
24    pub fn zero_grad(&mut self) {
25        self.gradient = 0.0;
26    }
27
28    pub fn update(&mut self, learning_rate: f32) {
29        self.value -= learning_rate * self.gradient;
30    }
31}
32
33/// Parameter store for managing trainable parameters
34#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
35pub struct ParameterStore {
36    params: HashMap<String, Parameter>,
37}
38
39impl ParameterStore {
40    pub fn new() -> Self {
41        Self {
42            params: HashMap::new(),
43        }
44    }
45
46    pub fn add_parameter(&mut self, name: &str, value: f32) -> &mut Parameter {
47        self.params.insert(name.to_string(), Parameter::new(value));
48        self.params.get_mut(name).unwrap()
49    }
50
51    pub fn get_parameter(&self, name: &str) -> Option<&Parameter> {
52        self.params.get(name)
53    }
54
55    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Parameter> {
56        self.params.get_mut(name)
57    }
58
59    pub fn zero_grad(&mut self) {
60        for param in self.params.values_mut() {
61            param.zero_grad();
62        }
63    }
64
65    pub fn update(&mut self, learning_rate: f32) {
66        for param in self.params.values_mut() {
67            param.update(learning_rate);
68        }
69    }
70
71    pub fn parameters(&self) -> &HashMap<String, Parameter> {
72        &self.params
73    }
74}
75
76/// Activation functions enum for neural networks
77#[derive(Clone, serde::Serialize, serde::Deserialize)]
78pub enum Activation {
79    ReLU,
80    Sigmoid,
81    Tanh,
82}
83
84impl Activation {
85    pub fn forward(&self, x: f32) -> f32 {
86        match self {
87            Activation::ReLU => x.max(0.0),
88            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
89            Activation::Tanh => x.tanh(),
90        }
91    }
92
93    pub fn backward(&self, x: f32) -> f32 {
94        match self {
95            Activation::ReLU => {
96                if x > 0.0 {
97                    1.0
98                } else {
99                    0.0
100                }
101            }
102            Activation::Sigmoid => {
103                let s = self.forward(x);
104                s * (1.0 - s)
105            }
106            Activation::Tanh => {
107                let t = self.forward(x);
108                1.0 - t * t
109            }
110        }
111    }
112}
113
// Individual activation types for compatibility.
// Zero-sized marker structs; derives added so public types are
// `Debug`-printable, copyable, comparable, and default-constructible.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ReLU;

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Sigmoid;

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Tanh;

124/// Trainable linear layer
125#[derive(Clone, serde::Serialize, serde::Deserialize)]
126pub struct Linear {
127    weight_name: String,
128    bias_name: String,
129}
130
131impl Linear {
132    pub fn new(layer_id: usize, _input_size: usize, _output_size: usize) -> Self {
133        Self {
134            weight_name: format!("layer_{}_weight", layer_id),
135            bias_name: format!("layer_{}_bias", layer_id),
136        }
137    }
138
139    pub fn init_parameters(&self, params: &mut ParameterStore) {
140        use rand::Rng;
141        let mut rng = rand::rng();
142
143        // Xavier initialization
144        let weight_init: f32 = rng.random_range(-0.5..0.5);
145        let bias_init: f32 = rng.random_range(-0.1..0.1);
146
147        params.add_parameter(&self.weight_name, weight_init);
148        params.add_parameter(&self.bias_name, bias_init);
149    }
150
151    pub fn forward(&self, x: f32, params: &ParameterStore) -> f32 {
152        let weight = params.get_parameter(&self.weight_name).unwrap().value;
153        let bias = params.get_parameter(&self.bias_name).unwrap().value;
154        x * weight + bias
155    }
156
157    pub fn backward(&self, x: f32, grad_output: f32, params: &mut ParameterStore) -> f32 {
158        let weight = params.get_parameter(&self.weight_name).unwrap().value;
159
160        // Compute gradients
161        let weight_grad = x * grad_output;
162        let bias_grad = grad_output;
163        let input_grad = weight * grad_output;
164
165        // Accumulate gradients
166        params
167            .get_parameter_mut(&self.weight_name)
168            .unwrap()
169            .gradient += weight_grad;
170        params.get_parameter_mut(&self.bias_name).unwrap().gradient += bias_grad;
171
172        input_grad
173    }
174}
175
176/// Serializable neural network state for persistence
177#[derive(serde::Serialize, serde::Deserialize)]
178pub struct NeuralNetworkState {
179    pub layers: Vec<Linear>,
180    pub activations: Vec<Activation>,
181    pub params: ParameterStore,
182}
183
184/// Trainable neural network
185pub struct TrainableNeuron {
186    layers: Vec<Linear>,
187    activations: Vec<Activation>,
188    params: ParameterStore,
189    // Store intermediate values for backpropagation (not serialized)
190    layer_inputs: Vec<f32>,
191    layer_outputs: Vec<f32>,
192}
193
194impl TrainableNeuron {
195    pub fn new(layer_sizes: Vec<usize>) -> Self {
196        let mut layers = Vec::new();
197        let mut activations = Vec::new();
198        let mut params = ParameterStore::new();
199
200        // Create layers
201        for i in 0..layer_sizes.len() - 1 {
202            let layer = Linear::new(i, layer_sizes[i], layer_sizes[i + 1]);
203            layer.init_parameters(&mut params);
204            layers.push(layer);
205
206            // Add activation (ReLU for hidden layers, Sigmoid for output)
207            if i == layer_sizes.len() - 2 {
208                activations.push(Activation::Sigmoid);
209            } else {
210                activations.push(Activation::ReLU);
211            }
212        }
213
214        Self {
215            layers,
216            activations,
217            params,
218            layer_inputs: vec![0.0; layer_sizes.len()],
219            layer_outputs: vec![0.0; layer_sizes.len()],
220        }
221    }
222
223    pub fn forward(&mut self, mut x: f32) -> f32 {
224        self.layer_inputs[0] = x;
225        self.layer_outputs[0] = x;
226
227        for i in 0..self.layers.len() {
228            // Linear transformation
229            x = self.layers[i].forward(x, &self.params);
230            self.layer_inputs[i + 1] = x;
231
232            // Activation
233            x = self.activations[i].forward(x);
234            self.layer_outputs[i + 1] = x;
235        }
236
237        x
238    }
239
240    pub fn backward(&mut self, target: f32) -> f32 {
241        let output = self.layer_outputs[self.layer_outputs.len() - 1];
242
243        // Mean squared error loss and its gradient
244        let loss = 0.5 * (output - target).powi(2);
245        let mut grad_output = output - target;
246
247        // Backpropagate through layers (reverse order)
248        for i in (0..self.layers.len()).rev() {
249            // Gradient through activation
250            let pre_activation = self.layer_inputs[i + 1];
251            grad_output = grad_output * self.activations[i].backward(pre_activation);
252
253            // Gradient through linear layer
254            let layer_input = self.layer_outputs[i];
255            grad_output = self.layers[i].backward(layer_input, grad_output, &mut self.params);
256        }
257
258        loss
259    }
260
261    pub fn zero_grad(&mut self) {
262        self.params.zero_grad();
263    }
264
265    pub fn update_parameters(&mut self, learning_rate: f32) {
266        self.params.update(learning_rate);
267    }
268
269    pub fn parameters(&self) -> &ParameterStore {
270        &self.params
271    }
272
273    pub fn parameters_mut(&mut self) -> &mut ParameterStore {
274        &mut self.params
275    }
276
277    /// Save neural network weights to file
278    pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), Box<dyn std::error::Error>> {
279        let state = NeuralNetworkState {
280            layers: self.layers.clone(),
281            activations: self.activations.clone(),
282            params: self.params.clone(),
283        };
284
285        let file = std::fs::File::create(path)?;
286        serde_json::to_writer_pretty(file, &state)?;
287        Ok(())
288    }
289
290    /// Load neural network weights from file
291    pub fn load_from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
292        let file = std::fs::File::open(path)?;
293        let state: NeuralNetworkState = serde_json::from_reader(file)?;
294
295        // Reconstruct the neural network from saved state
296        let layer_count = state.layers.len() + 1; // +1 for input
297        Ok(Self {
298            layers: state.layers,
299            activations: state.activations,
300            params: state.params,
301            layer_inputs: vec![0.0; layer_count],
302            layer_outputs: vec![0.0; layer_count],
303        })
304    }
305
306    /// Create new network or load from file if it exists
307    pub fn new_or_load(
308        layer_sizes: Vec<usize>,
309        save_path: &std::path::Path,
310        verbose: bool,
311    ) -> Self {
312        if save_path.exists() {
313            match Self::load_from_file(save_path) {
314                Ok(network) => {
315                    if verbose {
316                        println!("🧠 Loaded existing neural network from {:?}", save_path);
317                    }
318                    return network;
319                }
320                Err(e) => {
321                    if verbose {
322                        println!(
323                            "⚠️ Failed to load network from {:?}: {}, creating new one",
324                            save_path, e
325                        );
326                    }
327                }
328            }
329        }
330
331        println!("🧠 Creating new neural network");
332        Self::new(layer_sizes)
333    }
334}