eenn 0.1.0

A hybrid neural-symbolic constraint solver with cognitive reasoning capabilities
Documentation
//! Neural Network Training Infrastructure
//!
//! Provides trainable neural networks with gradient computation,
//! parameter storage, and optimization capabilities.

use std::collections::HashMap;

/// A single trainable scalar together with the gradient accumulated for it.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Parameter {
    /// Current value of the parameter.
    pub value: f32,
    /// Gradient accumulated since the last `zero_grad`. It is transient
    /// training state, so serde skips it (it deserializes as `0.0`).
    #[serde(skip)]
    pub gradient: f32,
}

impl Parameter {
    /// Creates a parameter holding `value` with a cleared gradient.
    pub fn new(value: f32) -> Self {
        Parameter {
            gradient: 0.0,
            value,
        }
    }

    /// Resets the accumulated gradient to zero.
    pub fn zero_grad(&mut self) {
        self.gradient = 0.0;
    }

    /// Applies one plain gradient-descent step:
    /// `value <- value - learning_rate * gradient`.
    pub fn update(&mut self, learning_rate: f32) {
        let step = learning_rate * self.gradient;
        self.value = self.value - step;
    }
}

/// Named collection of trainable [`Parameter`]s.
///
/// Parameters are addressed by string name; iteration order over the store
/// (e.g. via [`ParameterStore::parameters`]) is unspecified.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ParameterStore {
    params: HashMap<String, Parameter>,
}

impl ParameterStore {
    /// Creates an empty store (equivalent to `ParameterStore::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a parameter named `name` with initial `value` and a zero
    /// gradient, replacing any existing parameter of the same name, and
    /// returns a mutable reference to it.
    pub fn add_parameter(&mut self, name: &str, value: f32) -> &mut Parameter {
        use std::collections::hash_map::Entry;

        // One hash lookup, and no `unwrap` on a key we just inserted.
        match self.params.entry(name.to_owned()) {
            Entry::Occupied(slot) => {
                let param = slot.into_mut();
                *param = Parameter::new(value);
                param
            }
            Entry::Vacant(slot) => slot.insert(Parameter::new(value)),
        }
    }

    /// Looks up a parameter by name.
    pub fn get_parameter(&self, name: &str) -> Option<&Parameter> {
        self.params.get(name)
    }

    /// Looks up a parameter by name for in-place modification.
    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Parameter> {
        self.params.get_mut(name)
    }

    /// Clears the accumulated gradient of every parameter.
    pub fn zero_grad(&mut self) {
        for param in self.params.values_mut() {
            param.zero_grad();
        }
    }

    /// Applies one gradient-descent step with `learning_rate` to every parameter.
    pub fn update(&mut self, learning_rate: f32) {
        for param in self.params.values_mut() {
            param.update(learning_rate);
        }
    }

    /// Read-only view of all parameters, keyed by name.
    pub fn parameters(&self) -> &HashMap<String, Parameter> {
        &self.params
    }
}

/// Element-wise activation functions used between network layers.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum Activation {
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Logistic function: `1 / (1 + e^-x)`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
}

impl Activation {
    /// Applies the activation to the pre-activation value `x`.
    pub fn forward(&self, x: f32) -> f32 {
        match self {
            Activation::ReLU => x.max(0.0),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
        }
    }

    /// Derivative of the activation, evaluated at the pre-activation input `x`
    /// (not at the activation's output).
    ///
    /// The ReLU derivative at exactly `0.0` is taken to be `0.0`.
    pub fn backward(&self, x: f32) -> f32 {
        match self {
            Activation::ReLU => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            Activation::Sigmoid => {
                // sigma'(x) = sigma(x) * (1 - sigma(x))
                let s = self.forward(x);
                s * (1.0 - s)
            }
            Activation::Tanh => {
                // tanh'(x) = 1 - tanh(x)^2
                let t = self.forward(x);
                1.0 - t * t
            }
        }
    }
}

// Individual activation marker types, kept for compatibility. They carry no
// behavior of their own; the computations live on the `Activation` enum.

/// Unit marker type for the ReLU activation.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct ReLU;

/// Unit marker type for the Sigmoid activation.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Sigmoid;

/// Unit marker type for the Tanh activation.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Tanh;

/// Trainable linear layer
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct Linear {
    weight_name: String,
    bias_name: String,
}

impl Linear {
    pub fn new(layer_id: usize, _input_size: usize, _output_size: usize) -> Self {
        Self {
            weight_name: format!("layer_{}_weight", layer_id),
            bias_name: format!("layer_{}_bias", layer_id),
        }
    }

    pub fn init_parameters(&self, params: &mut ParameterStore) {
        use rand::Rng;
        let mut rng = rand::rng();

        // Xavier initialization
        let weight_init: f32 = rng.random_range(-0.5..0.5);
        let bias_init: f32 = rng.random_range(-0.1..0.1);

        params.add_parameter(&self.weight_name, weight_init);
        params.add_parameter(&self.bias_name, bias_init);
    }

    pub fn forward(&self, x: f32, params: &ParameterStore) -> f32 {
        let weight = params.get_parameter(&self.weight_name).unwrap().value;
        let bias = params.get_parameter(&self.bias_name).unwrap().value;
        x * weight + bias
    }

    pub fn backward(&self, x: f32, grad_output: f32, params: &mut ParameterStore) -> f32 {
        let weight = params.get_parameter(&self.weight_name).unwrap().value;

        // Compute gradients
        let weight_grad = x * grad_output;
        let bias_grad = grad_output;
        let input_grad = weight * grad_output;

        // Accumulate gradients
        params
            .get_parameter_mut(&self.weight_name)
            .unwrap()
            .gradient += weight_grad;
        params.get_parameter_mut(&self.bias_name).unwrap().gradient += bias_grad;

        input_grad
    }
}

/// Persistable snapshot of a trainable network. It holds the layer
/// descriptors, their activations and the parameter values; gradients are
/// not stored.
#[derive(serde::Deserialize, serde::Serialize)]
pub struct NeuralNetworkState {
    /// Layer descriptors, in forward order.
    pub layers: Vec<Linear>,
    /// One activation per layer, index-aligned with `layers`.
    pub activations: Vec<Activation>,
    /// Parameter values that the layers refer to by name.
    pub params: ParameterStore,
}

/// Trainable feed-forward network made of scalar layers, each followed by an
/// activation, trained by backpropagation of a squared-error loss.
pub struct TrainableNeuron {
    /// Layer descriptors, in forward order.
    layers: Vec<Linear>,
    /// Activation applied after the layer at the same index.
    activations: Vec<Activation>,
    /// Weights and biases of every layer, with their accumulated gradients.
    params: ParameterStore,
    // Forward-pass caches consumed by backpropagation; not part of the saved
    // state. Slot 0 of both holds the raw input. Slot i + 1 holds layer i's
    // pre-activation value (`layer_inputs`) and post-activation value
    // (`layer_outputs`).
    layer_inputs: Vec<f32>,
    layer_outputs: Vec<f32>,
}

impl TrainableNeuron {
    pub fn new(layer_sizes: Vec<usize>) -> Self {
        let mut layers = Vec::new();
        let mut activations = Vec::new();
        let mut params = ParameterStore::new();

        // Create layers
        for i in 0..layer_sizes.len() - 1 {
            let layer = Linear::new(i, layer_sizes[i], layer_sizes[i + 1]);
            layer.init_parameters(&mut params);
            layers.push(layer);

            // Add activation (ReLU for hidden layers, Sigmoid for output)
            if i == layer_sizes.len() - 2 {
                activations.push(Activation::Sigmoid);
            } else {
                activations.push(Activation::ReLU);
            }
        }

        Self {
            layers,
            activations,
            params,
            layer_inputs: vec![0.0; layer_sizes.len()],
            layer_outputs: vec![0.0; layer_sizes.len()],
        }
    }

    pub fn forward(&mut self, mut x: f32) -> f32 {
        self.layer_inputs[0] = x;
        self.layer_outputs[0] = x;

        for i in 0..self.layers.len() {
            // Linear transformation
            x = self.layers[i].forward(x, &self.params);
            self.layer_inputs[i + 1] = x;

            // Activation
            x = self.activations[i].forward(x);
            self.layer_outputs[i + 1] = x;
        }

        x
    }

    pub fn backward(&mut self, target: f32) -> f32 {
        let output = self.layer_outputs[self.layer_outputs.len() - 1];

        // Mean squared error loss and its gradient
        let loss = 0.5 * (output - target).powi(2);
        let mut grad_output = output - target;

        // Backpropagate through layers (reverse order)
        for i in (0..self.layers.len()).rev() {
            // Gradient through activation
            let pre_activation = self.layer_inputs[i + 1];
            grad_output = grad_output * self.activations[i].backward(pre_activation);

            // Gradient through linear layer
            let layer_input = self.layer_outputs[i];
            grad_output = self.layers[i].backward(layer_input, grad_output, &mut self.params);
        }

        loss
    }

    pub fn zero_grad(&mut self) {
        self.params.zero_grad();
    }

    pub fn update_parameters(&mut self, learning_rate: f32) {
        self.params.update(learning_rate);
    }

    pub fn parameters(&self) -> &ParameterStore {
        &self.params
    }

    pub fn parameters_mut(&mut self) -> &mut ParameterStore {
        &mut self.params
    }

    /// Save neural network weights to file
    pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), Box<dyn std::error::Error>> {
        let state = NeuralNetworkState {
            layers: self.layers.clone(),
            activations: self.activations.clone(),
            params: self.params.clone(),
        };

        let file = std::fs::File::create(path)?;
        serde_json::to_writer_pretty(file, &state)?;
        Ok(())
    }

    /// Load neural network weights from file
    pub fn load_from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
        let file = std::fs::File::open(path)?;
        let state: NeuralNetworkState = serde_json::from_reader(file)?;

        // Reconstruct the neural network from saved state
        let layer_count = state.layers.len() + 1; // +1 for input
        Ok(Self {
            layers: state.layers,
            activations: state.activations,
            params: state.params,
            layer_inputs: vec![0.0; layer_count],
            layer_outputs: vec![0.0; layer_count],
        })
    }

    /// Create new network or load from file if it exists
    pub fn new_or_load(
        layer_sizes: Vec<usize>,
        save_path: &std::path::Path,
        verbose: bool,
    ) -> Self {
        if save_path.exists() {
            match Self::load_from_file(save_path) {
                Ok(network) => {
                    if verbose {
                        println!("🧠 Loaded existing neural network from {:?}", save_path);
                    }
                    return network;
                }
                Err(e) => {
                    if verbose {
                        println!(
                            "⚠️ Failed to load network from {:?}: {}, creating new one",
                            save_path, e
                        );
                    }
                }
            }
        }

        println!("🧠 Creating new neural network");
        Self::new(layer_sizes)
    }
}