// ruv_fann/network.rs

use crate::{ActivationFunction, Layer, TrainingAlgorithm};
use num_traits::Float;
use rand::distributions::Uniform;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Errors that can occur during network operations
#[derive(Error, Debug)]
pub enum NetworkError {
    #[error("Input size mismatch: expected {expected}, got {actual}")]
    InputSizeMismatch { expected: usize, actual: usize },

    #[error("Weight count mismatch: expected {expected}, got {actual}")]
    WeightCountMismatch { expected: usize, actual: usize },

    #[error("Invalid layer configuration")]
    InvalidLayerConfiguration,

    #[error("Network has no layers")]
    NoLayers,
}

/// A feedforward neural network
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Network<T: Float> {
    /// The layers of the network
    pub layers: Vec<Layer<T>>,

    /// Connection rate (1.0 = fully connected, 0.0 = no connections)
    pub connection_rate: T,
}

impl<T: Float> Network<T> {
    /// Creates a new network with the specified layer sizes
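    ///
    /// Equivalent to `NetworkBuilder::new().layers_from_sizes(...).build()`.
    /// A minimal sketch (assuming `Network` is re-exported at the crate root
    /// alongside `NetworkBuilder`, hence not compiled as a doc-test):
    /// ```ignore
    /// use ruv_fann::Network;
    ///
    /// let network: Network<f32> = Network::new(&[2, 3, 1]);
    /// assert_eq!(network.num_layers(), 3);
    /// ```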
    pub fn new(layer_sizes: &[usize]) -> Self {
        NetworkBuilder::new().layers_from_sizes(layer_sizes).build()
    }

    /// Returns the number of layers in the network
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Returns the number of input neurons (excluding bias)
    pub fn num_inputs(&self) -> usize {
        self.layers
            .first()
            .map(|l| l.num_regular_neurons())
            .unwrap_or(0)
    }

    /// Returns the number of output neurons
    pub fn num_outputs(&self) -> usize {
        self.layers
            .last()
            .map(|l| l.num_regular_neurons())
            .unwrap_or(0)
    }

    /// Returns the total number of neurons in the network
    pub fn total_neurons(&self) -> usize {
        self.layers.iter().map(|l| l.size()).sum()
    }

    /// Returns the total number of connections in the network
    pub fn total_connections(&self) -> usize {
        self.layers
            .iter()
            .flat_map(|layer| &layer.neurons)
            .map(|neuron| neuron.connections.len())
            .sum()
    }

    /// Alias for `total_connections`, kept for API compatibility
    pub fn get_total_connections(&self) -> usize {
        self.total_connections()
    }

    /// Runs a forward pass through the network
    ///
    /// # Arguments
    /// * `inputs` - Input values for the network
    ///
    /// # Returns
    /// Output values from the network
    ///
    /// # Example
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let mut network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    ///
    /// let inputs = vec![0.5, 0.7];
    /// let outputs = network.run(&inputs);
    /// assert_eq!(outputs.len(), 1);
    /// ```
    pub fn run(&mut self, inputs: &[T]) -> Vec<T> {
        if self.layers.is_empty() {
            return Vec::new();
        }

        // Set input layer values
        if self.layers[0].set_inputs(inputs).is_err() {
            return Vec::new();
        }

        // Forward propagate through each layer
        for i in 1..self.layers.len() {
            let prev_outputs = self.layers[i - 1].get_outputs();
            self.layers[i].calculate(&prev_outputs);
        }

        // Return output layer values (excluding bias if present)
        if let Some(output_layer) = self.layers.last() {
            output_layer
                .neurons
                .iter()
                .filter(|n| !n.is_bias)
                .map(|n| n.value)
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Gets all weights in the network as a flat vector
    ///
    /// Weights are ordered by layer, then by neuron, then by connection
    pub fn get_weights(&self) -> Vec<T> {
        let mut weights = Vec::new();

        for layer in &self.layers {
            for neuron in &layer.neurons {
                for connection in &neuron.connections {
                    weights.push(connection.weight);
                }
            }
        }

        weights
    }

    /// Sets all weights in the network from a flat vector
    ///
    /// # Arguments
    /// * `weights` - New weights in the same order as returned by `get_weights`
    ///
    /// # Returns
    /// `Ok(())` on success, or `NetworkError::WeightCountMismatch` if the
    /// weight count doesn't match the number of connections
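    ///
    /// # Example
    ///
    /// A minimal round-trip sketch: read the weights out, nudge them, and
    /// write them back (illustrative only).
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let mut network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    ///
    /// let mut weights = network.get_weights();
    /// for w in &mut weights {
    ///     *w += 0.1;
    /// }
    /// network.set_weights(&weights).unwrap();
    ///
    /// // A wrong-sized slice is rejected.
    /// assert!(network.set_weights(&[0.0]).is_err());
    /// ```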
    pub fn set_weights(&mut self, weights: &[T]) -> Result<(), NetworkError> {
        let expected = self.total_connections();
        if weights.len() != expected {
            return Err(NetworkError::WeightCountMismatch {
                expected,
                actual: weights.len(),
            });
        }

        let mut weight_idx = 0;
        for layer in &mut self.layers {
            for neuron in &mut layer.neurons {
                for connection in &mut neuron.connections {
                    connection.weight = weights[weight_idx];
                    weight_idx += 1;
                }
            }
        }

        Ok(())
    }

    /// Resets all neurons in the network
    pub fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    /// Sets the activation function for all hidden layers
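    ///
    /// # Example
    ///
    /// A short sketch using the two variants this module already relies on
    /// (`Sigmoid` and `Linear`):
    /// ```
    /// use ruv_fann::{ActivationFunction, NetworkBuilder};
    ///
    /// let mut network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    ///
    /// network.set_activation_function_hidden(ActivationFunction::Sigmoid);
    /// network.set_activation_function_output(ActivationFunction::Linear);
    /// ```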
    pub fn set_activation_function_hidden(&mut self, activation_function: ActivationFunction) {
        // Skip input (0) and output (last) layers
        let num_layers = self.layers.len();
        if num_layers > 2 {
            for i in 1..num_layers - 1 {
                self.layers[i].set_activation_function(activation_function);
            }
        }
    }

    /// Sets the activation function for the output layer
    pub fn set_activation_function_output(&mut self, activation_function: ActivationFunction) {
        if let Some(output_layer) = self.layers.last_mut() {
            output_layer.set_activation_function(activation_function);
        }
    }

    /// Sets the activation steepness for all hidden layers
    pub fn set_activation_steepness_hidden(&mut self, steepness: T) {
        let num_layers = self.layers.len();
        if num_layers > 2 {
            for i in 1..num_layers - 1 {
                self.layers[i].set_activation_steepness(steepness);
            }
        }
    }

    /// Sets the activation steepness for the output layer
    pub fn set_activation_steepness_output(&mut self, steepness: T) {
        if let Some(output_layer) = self.layers.last_mut() {
            output_layer.set_activation_steepness(steepness);
        }
    }

    /// Sets the activation function for all neurons in a specific layer
    pub fn set_activation_function(
        &mut self,
        layer: usize,
        activation_function: ActivationFunction,
    ) {
        if layer < self.layers.len() {
            self.layers[layer].set_activation_function(activation_function);
        }
    }

    /// Randomizes all weights in the network within the given range
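    ///
    /// The range is half-open: weights are drawn from `[min, max)`.
    ///
    /// # Example
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let mut network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    ///
    /// network.randomize_weights(-1.0, 1.0);
    /// assert!(network.get_weights().iter().all(|&w| (-1.0..1.0).contains(&w)));
    /// ```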
    pub fn randomize_weights(&mut self, min: T, max: T)
    where
        T: rand::distributions::uniform::SampleUniform,
    {
        let mut rng = rand::thread_rng();
        let range = Uniform::new(min, max);

        for layer in &mut self.layers {
            for neuron in &mut layer.neurons {
                for connection in &mut neuron.connections {
                    connection.weight = rng.sample(&range);
                }
            }
        }
    }

    /// Sets the training algorithm (placeholder for API compatibility)
    pub fn set_training_algorithm(&mut self, _algorithm: TrainingAlgorithm) {
        // This is a placeholder for API compatibility
        // Actual training algorithm is selected when calling train methods
    }

    /// Train the network with the given data
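    ///
    /// Note: the body below is a simplified placeholder rather than full
    /// backpropagation, so the example only demonstrates the call shape.
    ///
    /// # Example
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let mut network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    ///
    /// // XOR-style training data.
    /// let inputs = vec![vec![0.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]];
    /// let targets = vec![vec![0.0], vec![1.0], vec![1.0], vec![0.0]];
    /// network.train(&inputs, &targets, 0.7, 10).unwrap();
    /// ```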
    pub fn train(
        &mut self,
        inputs: &[Vec<T>],
        outputs: &[Vec<T>],
        learning_rate: f32,
        epochs: usize,
    ) -> Result<(), NetworkError>
    where
        T: std::ops::AddAssign + std::ops::SubAssign + std::ops::MulAssign + std::cmp::PartialOrd,
    {
        if inputs.len() != outputs.len() {
            return Err(NetworkError::InvalidLayerConfiguration);
        }

        // Simplified placeholder training loop (not true gradient descent);
        // falls back to a default rate if the conversion to T fails
        let lr = T::from(learning_rate as f64).unwrap_or(T::from(0.7).unwrap_or(T::one()));

        for _epoch in 0..epochs {
            let mut total_error = T::zero();

            for (input, target) in inputs.iter().zip(outputs.iter()) {
                // Forward pass
                let output = self.run(input);

                // Accumulate squared error (currently unused by the
                // placeholder update rule below)
                for (o, t) in output.iter().zip(target.iter()) {
                    let diff = *o - *t;
                    total_error += diff * diff;
                }

                // Backward pass (simplified placeholder)
                // This is not proper backpropagation: every weight is nudged
                // by the same amount regardless of its contribution to the error
                for layer in &mut self.layers {
                    for neuron in &mut layer.neurons {
                        for connection in &mut neuron.connections {
                            // Simple uniform weight update
                            connection.weight -= lr * T::from(0.01).unwrap_or(T::one());
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Run batch inference on multiple inputs
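    ///
    /// # Example
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let mut network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    ///
    /// let batch = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
    /// let outputs = network.run_batch(&batch);
    /// assert_eq!(outputs.len(), 2);
    /// assert_eq!(outputs[0].len(), 1);
    /// ```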
    pub fn run_batch(&mut self, inputs: &[Vec<T>]) -> Vec<Vec<T>> {
        inputs.iter().map(|input| self.run(input)).collect()
    }

    /// Serialize the network to bytes
    #[cfg(all(feature = "binary", feature = "serde"))]
    pub fn to_bytes(&self) -> Vec<u8>
    where
        T: serde::Serialize,
        Network<T>: serde::Serialize,
    {
        bincode::serialize(self).unwrap_or_default()
    }

    #[cfg(feature = "binary")]
    #[cfg(not(feature = "serde"))]
    pub fn to_bytes(&self) -> Vec<u8> {
        // Fallback implementation when serde is not available
        Vec::new()
    }

    /// Deserialize a network from bytes
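    ///
    /// # Example
    ///
    /// A round-trip sketch; it assumes both the "binary" and "serde" features
    /// are enabled and that `Network` is re-exported at the crate root, so it
    /// is not compiled as a doc-test.
    /// ```ignore
    /// use ruv_fann::{Network, NetworkBuilder};
    ///
    /// let network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    ///
    /// let bytes = network.to_bytes();
    /// let restored = Network::<f32>::from_bytes(&bytes).unwrap();
    /// assert_eq!(restored.num_inputs(), 2);
    /// ```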
    #[cfg(all(feature = "binary", feature = "serde"))]
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, NetworkError>
    where
        T: serde::de::DeserializeOwned,
        Network<T>: serde::de::DeserializeOwned,
    {
        bincode::deserialize(bytes).map_err(|_| NetworkError::InvalidLayerConfiguration)
    }

    #[cfg(feature = "binary")]
    #[cfg(not(feature = "serde"))]
    pub fn from_bytes(_bytes: &[u8]) -> Result<Self, NetworkError> {
        // Fallback implementation when serde is not available
        Err(NetworkError::InvalidLayerConfiguration)
    }
}

/// Builder for creating neural networks with a fluent API
pub struct NetworkBuilder<T: Float> {
    layers: Vec<(usize, ActivationFunction, T)>,
    connection_rate: T,
}

impl<T: Float> NetworkBuilder<T> {
    /// Creates a new network builder
    ///
    /// # Example
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer(3)
    ///     .output_layer(1)
    ///     .build();
    /// ```
    pub fn new() -> Self {
        NetworkBuilder {
            layers: Vec::new(),
            connection_rate: T::one(),
        }
    }

    /// Create layers from a slice of layer sizes
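    ///
    /// The first size becomes the input layer, the last becomes the output
    /// layer, and any sizes in between become sigmoid hidden layers.
    ///
    /// # Example
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let network = NetworkBuilder::<f32>::new()
    ///     .layers_from_sizes(&[2, 3, 1])
    ///     .build();
    /// assert_eq!(network.num_layers(), 3);
    /// assert_eq!(network.num_inputs(), 2);
    /// assert_eq!(network.num_outputs(), 1);
    /// ```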
    pub fn layers_from_sizes(mut self, sizes: &[usize]) -> Self {
        if sizes.is_empty() {
            return self;
        }

        // First layer is input
        self.layers
            .push((sizes[0], ActivationFunction::Linear, T::one()));

        if sizes.len() > 1 {
            // Middle layers are hidden with sigmoid activation (the guard on
            // `sizes.len()` also avoids an out-of-range slice when only one
            // size is given)
            for &size in &sizes[1..sizes.len() - 1] {
                self.layers
                    .push((size, ActivationFunction::Sigmoid, T::one()));
            }

            // Last layer is output
            self.layers.push((
                sizes[sizes.len() - 1],
                ActivationFunction::Sigmoid,
                T::one(),
            ));
        }

        self
    }

    /// Adds an input layer to the network
    pub fn input_layer(mut self, size: usize) -> Self {
        self.layers
            .push((size, ActivationFunction::Linear, T::one()));
        self
    }

    /// Adds a hidden layer with default activation (Sigmoid)
    pub fn hidden_layer(mut self, size: usize) -> Self {
        self.layers
            .push((size, ActivationFunction::Sigmoid, T::one()));
        self
    }

    /// Adds a hidden layer with specific activation function
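    ///
    /// # Example
    ///
    /// A short sketch; `0.5` here is just an illustrative steepness value:
    /// ```
    /// use ruv_fann::{ActivationFunction, NetworkBuilder};
    ///
    /// let network = NetworkBuilder::<f32>::new()
    ///     .input_layer(2)
    ///     .hidden_layer_with_activation(3, ActivationFunction::Sigmoid, 0.5)
    ///     .output_layer(1)
    ///     .build();
    /// assert_eq!(network.num_layers(), 3);
    /// ```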
    pub fn hidden_layer_with_activation(
        mut self,
        size: usize,
        activation: ActivationFunction,
        steepness: T,
    ) -> Self {
        self.layers.push((size, activation, steepness));
        self
    }

    /// Adds an output layer with default activation (Sigmoid)
    pub fn output_layer(mut self, size: usize) -> Self {
        self.layers
            .push((size, ActivationFunction::Sigmoid, T::one()));
        self
    }

    /// Adds an output layer with specific activation function
    pub fn output_layer_with_activation(
        mut self,
        size: usize,
        activation: ActivationFunction,
        steepness: T,
    ) -> Self {
        self.layers.push((size, activation, steepness));
        self
    }

    /// Sets the connection rate (0.0 to 1.0)
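    ///
    /// # Example
    ///
    /// Comparing a sparse network against a fully connected one of the same
    /// shape (connections are created at random, so only the relative count
    /// is asserted):
    /// ```
    /// use ruv_fann::NetworkBuilder;
    ///
    /// let sparse = NetworkBuilder::<f32>::new()
    ///     .input_layer(10)
    ///     .hidden_layer(10)
    ///     .output_layer(10)
    ///     .connection_rate(0.5)
    ///     .build();
    /// let dense = NetworkBuilder::<f32>::new()
    ///     .input_layer(10)
    ///     .hidden_layer(10)
    ///     .output_layer(10)
    ///     .build();
    /// assert!(sparse.total_connections() < dense.total_connections());
    /// ```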
    pub fn connection_rate(mut self, rate: T) -> Self {
        self.connection_rate = rate;
        self
    }

    /// Builds the network
    pub fn build(self) -> Network<T> {
        let mut network_layers = Vec::new();

        // Create layers
        for (i, &(size, activation, steepness)) in self.layers.iter().enumerate() {
            let layer = if i == 0 {
                // Input layer with bias
                Layer::with_bias(size, activation, steepness)
            } else if i == self.layers.len() - 1 {
                // Output layer without bias
                Layer::new(size, activation, steepness)
            } else {
                // Hidden layer with bias
                Layer::with_bias(size, activation, steepness)
            };
            network_layers.push(layer);
        }

        // Connect adjacent layers; `saturating_sub` guards against an
        // underflow panic when the builder has no layers
        for i in 0..network_layers.len().saturating_sub(1) {
            let (before, after) = network_layers.split_at_mut(i + 1);
            before[i].connect_to(&mut after[0], self.connection_rate);
        }

        Network {
            layers: network_layers,
            connection_rate: self.connection_rate,
        }
    }
}

impl<T: Float> Default for NetworkBuilder<T> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_network_builder() {
        let network: Network<f32> = NetworkBuilder::new()
            .input_layer(2)
            .hidden_layer(3)
            .output_layer(1)
            .build();

        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.num_inputs(), 2);
        assert_eq!(network.num_outputs(), 1);
    }

    #[test]
    fn test_network_run() {
        let mut network: Network<f32> = NetworkBuilder::new()
            .input_layer(2)
            .hidden_layer(3)
            .output_layer(1)
            .build();

        let inputs = vec![0.5, 0.7];
        let outputs = network.run(&inputs);
        assert_eq!(outputs.len(), 1);
    }

    #[test]
    fn test_total_neurons() {
        let network: Network<f32> = NetworkBuilder::new()
            .input_layer(2) // 2 + 1 bias = 3
            .hidden_layer(3) // 3 + 1 bias = 4
            .output_layer(1) // 1 (no bias) = 1
            .build();

        assert_eq!(network.total_neurons(), 8);
    }

    #[test]
    fn test_sparse_network() {
        let network: Network<f32> = NetworkBuilder::new()
            .input_layer(10)
            .hidden_layer(10)
            .output_layer(10)
            .connection_rate(0.5)
            .build();

        // Should have fewer connections than a fully connected network
        let connections = network.total_connections();
        let max_connections = 11 * 10 + 11 * 10; // (10+1)*10 + (10+1)*10

        assert!(connections < max_connections);
    }
}