use std::collections::HashMap;
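
/// A single scalar parameter together with its accumulated gradient.
/// The gradient is transient training state: `#[serde(skip)]` keeps it out of
/// snapshots, and it deserializes back to the `f32` default of 0.0.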
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Parameter {
pub value: f32,
    #[serde(skip)]
    pub gradient: f32,
}
impl Parameter {
pub fn new(value: f32) -> Self {
Self {
value,
gradient: 0.0,
}
}
pub fn zero_grad(&mut self) {
self.gradient = 0.0;
}
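    /// One step of plain gradient descent: value -= learning_rate * gradient.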
pub fn update(&mut self, learning_rate: f32) {
self.value -= learning_rate * self.gradient;
}
}
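
/// A flat, name-keyed store holding every parameter in the network.
/// Layers do not own their weights; they look them up here by name.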
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ParameterStore {
params: HashMap<String, Parameter>,
}
impl ParameterStore {
pub fn new() -> Self {
Self {
params: HashMap::new(),
}
}
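    /// Inserts a parameter under `name`, replacing any existing entry,
    /// and returns a mutable reference to the stored value.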
pub fn add_parameter(&mut self, name: &str, value: f32) -> &mut Parameter {
self.params.insert(name.to_string(), Parameter::new(value));
self.params.get_mut(name).unwrap()
}
pub fn get_parameter(&self, name: &str) -> Option<&Parameter> {
self.params.get(name)
}
pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Parameter> {
self.params.get_mut(name)
}
pub fn zero_grad(&mut self) {
for param in self.params.values_mut() {
param.zero_grad();
}
}
pub fn update(&mut self, learning_rate: f32) {
for param in self.params.values_mut() {
param.update(learning_rate);
}
}
pub fn parameters(&self) -> &HashMap<String, Parameter> {
&self.params
}
}
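
// A minimal `Default` that mirrors `new()`; an added convenience, not required
// by anything else in this module.
impl Default for ParameterStore {
    fn default() -> Self {
        Self::new()
    }
}

/// Pointwise activation functions with closed-form derivatives.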
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum Activation {
ReLU,
Sigmoid,
Tanh,
}
impl Activation {
pub fn forward(&self, x: f32) -> f32 {
match self {
Activation::ReLU => x.max(0.0),
Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
Activation::Tanh => x.tanh(),
}
}
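    /// Derivative of the activation evaluated at the pre-activation value `x`.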
pub fn backward(&self, x: f32) -> f32 {
match self {
Activation::ReLU => {
if x > 0.0 {
1.0
} else {
0.0
}
}
Activation::Sigmoid => {
let s = self.forward(x);
s * (1.0 - s)
}
Activation::Tanh => {
let t = self.forward(x);
1.0 - t * t
}
}
}
}
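/// A scalar affine layer, y = w * x + b. The struct stores only the *names*
/// of its weight and bias; the values live in the shared `ParameterStore`.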
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Linear {
weight_name: String,
bias_name: String,
}
impl Linear {
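    /// The size arguments are accepted for API symmetry but unused:
    /// every layer in this module maps a single `f32` to a single `f32`.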
pub fn new(layer_id: usize, _input_size: usize, _output_size: usize) -> Self {
Self {
weight_name: format!("layer_{}_weight", layer_id),
bias_name: format!("layer_{}_bias", layer_id),
}
}
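    /// Registers this layer's weight and bias with small random initial values.
    /// (Uses the rand 0.9 API: `rand::rng()` and `random_range`.)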
pub fn init_parameters(&self, params: &mut ParameterStore) {
use rand::Rng;
let mut rng = rand::rng();
let weight_init: f32 = rng.random_range(-0.5..0.5);
let bias_init: f32 = rng.random_range(-0.1..0.1);
params.add_parameter(&self.weight_name, weight_init);
params.add_parameter(&self.bias_name, bias_init);
}
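    /// y = weight * x + bias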
pub fn forward(&self, x: f32, params: &ParameterStore) -> f32 {
let weight = params.get_parameter(&self.weight_name).unwrap().value;
let bias = params.get_parameter(&self.bias_name).unwrap().value;
x * weight + bias
}
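    /// Chain rule for y = w * x + b: dL/dw = x * g, dL/db = g, and the returned
    /// input gradient is dL/dx = w * g. Gradients accumulate until `zero_grad`.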
pub fn backward(&self, x: f32, grad_output: f32, params: &mut ParameterStore) -> f32 {
let weight = params.get_parameter(&self.weight_name).unwrap().value;
let weight_grad = x * grad_output;
let bias_grad = grad_output;
let input_grad = weight * grad_output;
params
.get_parameter_mut(&self.weight_name)
.unwrap()
.gradient += weight_grad;
params.get_parameter_mut(&self.bias_name).unwrap().gradient += bias_grad;
input_grad
}
}
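
/// Serializable snapshot of a `TrainableNeuron`: topology plus parameters.
/// The per-pass caches are rebuilt on load, so they are not stored.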
#[derive(serde::Serialize, serde::Deserialize)]
pub struct NeuralNetworkState {
pub layers: Vec<Linear>,
pub activations: Vec<Activation>,
pub params: ParameterStore,
}
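
/// A scalar feed-forward network: every layer maps one `f32` to one `f32`.
/// Each forward pass caches its intermediate values so `backward` can run
/// ordinary backpropagation over them.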
pub struct TrainableNeuron {
layers: Vec<Linear>,
activations: Vec<Activation>,
params: ParameterStore,
    /// Cached pre-activation value entering each stage; index 0 is the raw input.
    layer_inputs: Vec<f32>,
    /// Cached post-activation value leaving each stage; the last entry is the output.
    layer_outputs: Vec<f32>,
}
impl TrainableNeuron {
    pub fn new(layer_sizes: Vec<usize>) -> Self {
        // Guard the `len() - 1` below against underflow on an empty input.
        assert!(
            layer_sizes.len() >= 2,
            "layer_sizes needs at least an input and an output size"
        );
        let mut layers = Vec::new();
        let mut activations = Vec::new();
        let mut params = ParameterStore::new();
        for i in 0..layer_sizes.len() - 1 {
let layer = Linear::new(i, layer_sizes[i], layer_sizes[i + 1]);
layer.init_parameters(&mut params);
layers.push(layer);
            // Sigmoid on the output layer, ReLU on every hidden layer.
            if i == layer_sizes.len() - 2 {
activations.push(Activation::Sigmoid);
} else {
activations.push(Activation::ReLU);
}
}
Self {
layers,
activations,
params,
layer_inputs: vec![0.0; layer_sizes.len()],
layer_outputs: vec![0.0; layer_sizes.len()],
}
}
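    /// Runs the network on `x`, caching pre- and post-activation values
    /// at every stage for the subsequent `backward` pass.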
pub fn forward(&mut self, mut x: f32) -> f32 {
self.layer_inputs[0] = x;
self.layer_outputs[0] = x;
for i in 0..self.layers.len() {
x = self.layers[i].forward(x, &self.params);
self.layer_inputs[i + 1] = x;
x = self.activations[i].forward(x);
self.layer_outputs[i + 1] = x;
}
x
}
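    /// Backpropagates against `target` using the caches from the most recent
    /// `forward` call, accumulating gradients into the parameter store.
    /// Returns the squared-error loss 0.5 * (output - target)^2.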
pub fn backward(&mut self, target: f32) -> f32 {
let output = self.layer_outputs[self.layer_outputs.len() - 1];
let loss = 0.5 * (output - target).powi(2);
let mut grad_output = output - target;
for i in (0..self.layers.len()).rev() {
let pre_activation = self.layer_inputs[i + 1];
grad_output = grad_output * self.activations[i].backward(pre_activation);
let layer_input = self.layer_outputs[i];
grad_output = self.layers[i].backward(layer_input, grad_output, &mut self.params);
}
loss
}
pub fn zero_grad(&mut self) {
self.params.zero_grad();
}
pub fn update_parameters(&mut self, learning_rate: f32) {
self.params.update(learning_rate);
}
pub fn parameters(&self) -> &ParameterStore {
&self.params
}
pub fn parameters_mut(&mut self) -> &mut ParameterStore {
&mut self.params
}
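    /// Serializes topology and parameters to pretty-printed JSON.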
pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), Box<dyn std::error::Error>> {
let state = NeuralNetworkState {
layers: self.layers.clone(),
activations: self.activations.clone(),
params: self.params.clone(),
};
let file = std::fs::File::create(path)?;
serde_json::to_writer_pretty(file, &state)?;
Ok(())
}
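    /// Restores a network from a JSON snapshot, sizing fresh per-pass caches.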
pub fn load_from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
let file = std::fs::File::open(path)?;
let state: NeuralNetworkState = serde_json::from_reader(file)?;
        // One slot more than the layer count: index 0 holds the raw input.
        let layer_count = state.layers.len() + 1;
        Ok(Self {
layers: state.layers,
activations: state.activations,
params: state.params,
layer_inputs: vec![0.0; layer_count],
layer_outputs: vec![0.0; layer_count],
})
}
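    /// Loads the network from `save_path` when a readable snapshot exists,
    /// otherwise builds a fresh one from `layer_sizes`.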
pub fn new_or_load(
layer_sizes: Vec<usize>,
save_path: &std::path::Path,
verbose: bool,
) -> Self {
if save_path.exists() {
match Self::load_from_file(save_path) {
Ok(network) => {
if verbose {
println!("🧠 Loaded existing neural network from {:?}", save_path);
}
return network;
}
Err(e) => {
if verbose {
println!(
"⚠️ Failed to load network from {:?}: {}, creating new one",
save_path, e
);
}
}
}
}
println!("🧠 Creating new neural network");
Self::new(layer_sizes)
}
}
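
#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative sketch of the intended training loop, not a spec: the
    // topology [1, 4, 1], target 0.25, learning rate 0.1, and step count are
    // arbitrary choices for this test.
    #[test]
    fn training_reduces_loss() {
        let mut net = TrainableNeuron::new(vec![1, 4, 1]);
        let (input, target) = (1.0_f32, 0.25_f32);

        // Loss before any updates; `backward` needs a cached forward pass.
        net.forward(input);
        let initial_loss = net.backward(target);
        net.zero_grad();

        // Plain gradient descent: forward, backprop, step, clear gradients.
        for _ in 0..200 {
            net.forward(input);
            net.backward(target);
            net.update_parameters(0.1);
            net.zero_grad();
        }

        net.forward(input);
        let final_loss = net.backward(target);
        assert!(
            final_loss <= initial_loss,
            "loss went up: {} -> {}",
            initial_loss,
            final_loss
        );
    }
}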