1use std::collections::HashMap;
7
8#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
10pub struct Parameter {
11 pub value: f32,
12 #[serde(skip)] pub gradient: f32,
14}
15
16impl Parameter {
17 pub fn new(value: f32) -> Self {
18 Self {
19 value,
20 gradient: 0.0,
21 }
22 }
23
24 pub fn zero_grad(&mut self) {
25 self.gradient = 0.0;
26 }
27
28 pub fn update(&mut self, learning_rate: f32) {
29 self.value -= learning_rate * self.gradient;
30 }
31}
32
33#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
35pub struct ParameterStore {
36 params: HashMap<String, Parameter>,
37}
38
39impl ParameterStore {
40 pub fn new() -> Self {
41 Self {
42 params: HashMap::new(),
43 }
44 }
45
46 pub fn add_parameter(&mut self, name: &str, value: f32) -> &mut Parameter {
47 self.params.insert(name.to_string(), Parameter::new(value));
48 self.params.get_mut(name).unwrap()
49 }
50
51 pub fn get_parameter(&self, name: &str) -> Option<&Parameter> {
52 self.params.get(name)
53 }
54
55 pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Parameter> {
56 self.params.get_mut(name)
57 }
58
59 pub fn zero_grad(&mut self) {
60 for param in self.params.values_mut() {
61 param.zero_grad();
62 }
63 }
64
65 pub fn update(&mut self, learning_rate: f32) {
66 for param in self.params.values_mut() {
67 param.update(learning_rate);
68 }
69 }
70
71 pub fn parameters(&self) -> &HashMap<String, Parameter> {
72 &self.params
73 }
74}
75
76#[derive(Clone, serde::Serialize, serde::Deserialize)]
78pub enum Activation {
79 ReLU,
80 Sigmoid,
81 Tanh,
82}
83
84impl Activation {
85 pub fn forward(&self, x: f32) -> f32 {
86 match self {
87 Activation::ReLU => x.max(0.0),
88 Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
89 Activation::Tanh => x.tanh(),
90 }
91 }
92
93 pub fn backward(&self, x: f32) -> f32 {
94 match self {
95 Activation::ReLU => {
96 if x > 0.0 {
97 1.0
98 } else {
99 0.0
100 }
101 }
102 Activation::Sigmoid => {
103 let s = self.forward(x);
104 s * (1.0 - s)
105 }
106 Activation::Tanh => {
107 let t = self.forward(x);
108 1.0 - t * t
109 }
110 }
111 }
112}
113
// Zero-sized marker types for the individual activations.
// NOTE(review): nothing visible in this file uses these — dispatch goes
// through the `Activation` enum instead. Presumably kept for external API
// compatibility; confirm callers before removing.
#[derive(Clone)]
pub struct ReLU;

#[derive(Clone)]
pub struct Sigmoid;

#[derive(Clone)]
pub struct Tanh;
123
124#[derive(Clone, serde::Serialize, serde::Deserialize)]
126pub struct Linear {
127 weight_name: String,
128 bias_name: String,
129}
130
131impl Linear {
132 pub fn new(layer_id: usize, _input_size: usize, _output_size: usize) -> Self {
133 Self {
134 weight_name: format!("layer_{}_weight", layer_id),
135 bias_name: format!("layer_{}_bias", layer_id),
136 }
137 }
138
139 pub fn init_parameters(&self, params: &mut ParameterStore) {
140 use rand::Rng;
141 let mut rng = rand::rng();
142
143 let weight_init: f32 = rng.random_range(-0.5..0.5);
145 let bias_init: f32 = rng.random_range(-0.1..0.1);
146
147 params.add_parameter(&self.weight_name, weight_init);
148 params.add_parameter(&self.bias_name, bias_init);
149 }
150
151 pub fn forward(&self, x: f32, params: &ParameterStore) -> f32 {
152 let weight = params.get_parameter(&self.weight_name).unwrap().value;
153 let bias = params.get_parameter(&self.bias_name).unwrap().value;
154 x * weight + bias
155 }
156
157 pub fn backward(&self, x: f32, grad_output: f32, params: &mut ParameterStore) -> f32 {
158 let weight = params.get_parameter(&self.weight_name).unwrap().value;
159
160 let weight_grad = x * grad_output;
162 let bias_grad = grad_output;
163 let input_grad = weight * grad_output;
164
165 params
167 .get_parameter_mut(&self.weight_name)
168 .unwrap()
169 .gradient += weight_grad;
170 params.get_parameter_mut(&self.bias_name).unwrap().gradient += bias_grad;
171
172 input_grad
173 }
174}
175
/// Serializable snapshot of a `TrainableNeuron`: everything needed to
/// reconstruct the network except the transient forward/backward caches.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct NeuralNetworkState {
    pub layers: Vec<Linear>,
    pub activations: Vec<Activation>,
    pub params: ParameterStore,
}
183
/// A small scalar feed-forward network (one `f32` in, one `f32` out)
/// trained with manually implemented backpropagation.
pub struct TrainableNeuron {
    layers: Vec<Linear>,
    // One activation per layer, applied after the linear step.
    activations: Vec<Activation>,
    params: ParameterStore,
    // Cached pre-activation values from the last `forward` call,
    // indexed by layer boundary (slot 0 holds the raw network input).
    layer_inputs: Vec<f32>,
    // Cached post-activation values from the last `forward` call.
    layer_outputs: Vec<f32>,
}
193
194impl TrainableNeuron {
195 pub fn new(layer_sizes: Vec<usize>) -> Self {
196 let mut layers = Vec::new();
197 let mut activations = Vec::new();
198 let mut params = ParameterStore::new();
199
200 for i in 0..layer_sizes.len() - 1 {
202 let layer = Linear::new(i, layer_sizes[i], layer_sizes[i + 1]);
203 layer.init_parameters(&mut params);
204 layers.push(layer);
205
206 if i == layer_sizes.len() - 2 {
208 activations.push(Activation::Sigmoid);
209 } else {
210 activations.push(Activation::ReLU);
211 }
212 }
213
214 Self {
215 layers,
216 activations,
217 params,
218 layer_inputs: vec![0.0; layer_sizes.len()],
219 layer_outputs: vec![0.0; layer_sizes.len()],
220 }
221 }
222
223 pub fn forward(&mut self, mut x: f32) -> f32 {
224 self.layer_inputs[0] = x;
225 self.layer_outputs[0] = x;
226
227 for i in 0..self.layers.len() {
228 x = self.layers[i].forward(x, &self.params);
230 self.layer_inputs[i + 1] = x;
231
232 x = self.activations[i].forward(x);
234 self.layer_outputs[i + 1] = x;
235 }
236
237 x
238 }
239
240 pub fn backward(&mut self, target: f32) -> f32 {
241 let output = self.layer_outputs[self.layer_outputs.len() - 1];
242
243 let loss = 0.5 * (output - target).powi(2);
245 let mut grad_output = output - target;
246
247 for i in (0..self.layers.len()).rev() {
249 let pre_activation = self.layer_inputs[i + 1];
251 grad_output = grad_output * self.activations[i].backward(pre_activation);
252
253 let layer_input = self.layer_outputs[i];
255 grad_output = self.layers[i].backward(layer_input, grad_output, &mut self.params);
256 }
257
258 loss
259 }
260
261 pub fn zero_grad(&mut self) {
262 self.params.zero_grad();
263 }
264
265 pub fn update_parameters(&mut self, learning_rate: f32) {
266 self.params.update(learning_rate);
267 }
268
269 pub fn parameters(&self) -> &ParameterStore {
270 &self.params
271 }
272
273 pub fn parameters_mut(&mut self) -> &mut ParameterStore {
274 &mut self.params
275 }
276
277 pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), Box<dyn std::error::Error>> {
279 let state = NeuralNetworkState {
280 layers: self.layers.clone(),
281 activations: self.activations.clone(),
282 params: self.params.clone(),
283 };
284
285 let file = std::fs::File::create(path)?;
286 serde_json::to_writer_pretty(file, &state)?;
287 Ok(())
288 }
289
290 pub fn load_from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
292 let file = std::fs::File::open(path)?;
293 let state: NeuralNetworkState = serde_json::from_reader(file)?;
294
295 let layer_count = state.layers.len() + 1; Ok(Self {
298 layers: state.layers,
299 activations: state.activations,
300 params: state.params,
301 layer_inputs: vec![0.0; layer_count],
302 layer_outputs: vec![0.0; layer_count],
303 })
304 }
305
306 pub fn new_or_load(
308 layer_sizes: Vec<usize>,
309 save_path: &std::path::Path,
310 verbose: bool,
311 ) -> Self {
312 if save_path.exists() {
313 match Self::load_from_file(save_path) {
314 Ok(network) => {
315 if verbose {
316 println!("🧠 Loaded existing neural network from {:?}", save_path);
317 }
318 return network;
319 }
320 Err(e) => {
321 if verbose {
322 println!(
323 "⚠️ Failed to load network from {:?}: {}, creating new one",
324 save_path, e
325 );
326 }
327 }
328 }
329 }
330
331 println!("🧠 Creating new neural network");
332 Self::new(layer_sizes)
333 }
334}