Trait NeuralNetwork

pub trait NeuralNetwork<const I: usize, const O: usize> {
    type LayerOutputs: LayerOutputs<O>;

    // Required methods
    fn feedforward(&self, input: &RowVector<I>) -> Self::LayerOutputs;
    fn backpropagate(
        &mut self,
        input: &RowVector<I>,
        output: &RowVector<O>,
        layer_outputs: Self::LayerOutputs,
    ) -> f32;

    // Provided methods
    fn train_once<BI, BO>(&mut self, input: BI, output: BO) -> f32
       where BI: Borrow<RowVector<I>>,
             BO: Borrow<RowVector<O>> { ... }
    fn train<D, T>(
        &mut self,
        data: D,
        callback: Option<impl FnMut(usize, f32)>,
    ) -> (usize, f32)
       where D: IntoIterator<Item = T>,
             T: Borrow<(RowVector<I>, RowVector<O>)> { ... }
    fn predict(&self, input: &RowVector<I>) -> RowVector<O> { ... }
}

A trait for neural networks, with I inputs and O outputs.

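Associated Types

  • LayerOutputs: The type representing the outputs of each layer in the network.

Required methods

  • feedforward: Compute the outputs of every layer for a given input.
  • backpropagate: Adjust weights and biases toward a target output, returning the loss.

Provided methods

  • train_once: Train on a single input/target pair.
  • train: Train on a sequence of input/target pairs.
  • predict: Predict the output for a given input.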

Required Associated Types

type LayerOutputs: LayerOutputs<O>

The type representing the outputs of each layer in the network.

Required Methods

fn feedforward(&self, input: &RowVector<I>) -> Self::LayerOutputs

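Feed the input forward through the network and return the outputs of every layer, including the final output, as Self::LayerOutputs. Because the method takes &self, the network itself is not modified; the layer outputs are computed and returned rather than stored.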
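As an illustration only, here is a minimal sketch of a feedforward pass for a hypothetical network with one hidden layer of H units and sigmoid activations. It assumes RowVector<N> can be approximated by a plain [f32; N] array and reduces the LayerOutputs trait to a single output() accessor; HiddenLayerNet, HiddenOutputs, dense and sigmoid are hypothetical names, not the crate's types.

```rust
/// Hypothetical stand-in for the crate's 1×N `RowVector<N>`.
type RowVector<const N: usize> = [f32; N];

/// Hypothetical, reduced version of the `LayerOutputs<O>` trait:
/// it only exposes the network's final output.
trait LayerOutputs<const O: usize> {
    fn output(&self) -> &RowVector<O>;
}

/// Activations of every layer, produced by one forward pass.
struct HiddenOutputs<const H: usize, const O: usize> {
    hidden: RowVector<H>,
    output: RowVector<O>,
}

impl<const H: usize, const O: usize> LayerOutputs<O> for HiddenOutputs<H, O> {
    fn output(&self) -> &RowVector<O> {
        &self.output
    }
}

fn sigmoid(z: f32) -> f32 {
    1.0 / (1.0 + (-z).exp())
}

/// Computes sigmoid(x · w + b) for a 1×N row vector x and an N×M weight matrix w.
fn dense<const N: usize, const M: usize>(
    x: &RowVector<N>,
    w: &[[f32; M]; N],
    b: &RowVector<M>,
) -> RowVector<M> {
    let mut out = *b; // start from the bias
    for (xi, row) in x.iter().zip(w.iter()) {
        for (o, wij) in out.iter_mut().zip(row.iter()) {
            *o += xi * wij;
        }
    }
    out.map(sigmoid)
}

/// Hypothetical I -> H -> O network with sigmoid activations.
struct HiddenLayerNet<const I: usize, const H: usize, const O: usize> {
    w1: [[f32; H]; I], // input-to-hidden weights, shape (I, H)
    b1: RowVector<H>,
    w2: [[f32; O]; H], // hidden-to-output weights, shape (H, O)
    b2: RowVector<O>,
}

impl<const I: usize, const H: usize, const O: usize> HiddenLayerNet<I, H, O> {
    /// Mirrors the documented shape: `&self` in, per-layer outputs out.
    fn feedforward(&self, input: &RowVector<I>) -> HiddenOutputs<H, O> {
        let hidden = dense(input, &self.w1, &self.b1);
        let output = dense(&hidden, &self.w2, &self.b2);
        HiddenOutputs { hidden, output }
    }
}
```

Returning the hidden activations alongside the output is what lets a later backpropagation step reuse them without recomputing the forward pass.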

fn backpropagate(
    &mut self,
    input: &RowVector<I>,
    output: &RowVector<O>,
    layer_outputs: Self::LayerOutputs,
) -> f32

Perform backpropagation to adjust weights and biases based on the target output, returning the loss.

Arguments
  • input: The input row vector, a matrix of shape (1, I).
  • output: The target output row vector, a matrix of shape (1, O).
  • layer_outputs: The per-layer outputs returned by feedforward for the same input.
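A minimal sketch of one backpropagation step, under loudly stated assumptions: a hypothetical single-layer network y = sigmoid(x · W + b) with no hidden layer, a mean-squared-error loss averaged over the O outputs, and a plain gradient-descent update scaled by a hypothetical learning_rate field. The crate's actual loss and update rule are not documented on this page. As in the documented signature, the layer outputs are taken by value from a preceding feedforward call.

```rust
/// Hypothetical stand-in for the crate's 1×N `RowVector<N>`.
type RowVector<const N: usize> = [f32; N];

/// Hypothetical single-layer network: y = sigmoid(x · w + b), w has shape (I, O).
struct SingleLayerNet<const I: usize, const O: usize> {
    w: [[f32; O]; I],
    b: RowVector<O>,
    learning_rate: f32, // hypothetical hyperparameter
}

impl<const I: usize, const O: usize> SingleLayerNet<I, O> {
    /// Forward pass; with a single layer the "layer outputs" are just y.
    fn feedforward(&self, input: &RowVector<I>) -> RowVector<O> {
        let mut z = self.b;
        for (xi, row) in input.iter().zip(self.w.iter()) {
            for (zk, wik) in z.iter_mut().zip(row.iter()) {
                *zk += xi * wik;
            }
        }
        z.map(|v| 1.0 / (1.0 + (-v).exp()))
    }

    /// One gradient-descent step (assumed MSE + sigmoid); returns the loss before the update.
    fn backpropagate(
        &mut self,
        input: &RowVector<I>,
        output: &RowVector<O>,
        layer_outputs: RowVector<O>,
    ) -> f32 {
        let y = layer_outputs;
        let n = O as f32;
        let mut loss = 0.0;
        for k in 0..O {
            let err = y[k] - output[k];
            loss += err * err / n;
            // Chain rule: dL/dz_k = (2 * err / O) * y_k * (1 - y_k)
            let delta = 2.0 * err / n * y[k] * (1.0 - y[k]);
            for i in 0..I {
                self.w[i][k] -= self.learning_rate * input[i] * delta;
            }
            self.b[k] -= self.learning_rate * delta;
        }
        loss
    }
}
```

Passing layer_outputs in, instead of recomputing them inside backpropagate, avoids a second forward pass per training step.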

Provided Methods

fn train_once<BI, BO>(&mut self, input: BI, output: BO) -> f32
where BI: Borrow<RowVector<I>>, BO: Borrow<RowVector<O>>,

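Train the network once on a single (input, target output) pair, returning the average loss. Because both arguments are only required to implement Borrow, they can be passed either by value or by reference.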
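Since train_once is a provided method, its default body can only be built from the two required methods. A plausible shape, sketched here on a reduced version of the trait, is one forward pass followed by one backpropagation step. RowVector is approximated by [f32; N], and the Trainable trait and Identity toy model are hypothetical.

```rust
use std::borrow::Borrow;

/// Hypothetical stand-in for the crate's 1×N `RowVector<N>`.
type RowVector<const N: usize> = [f32; N];

/// Reduced, hypothetical version of the trait: the two required methods plus `train_once`.
trait Trainable<const I: usize, const O: usize> {
    type LayerOutputs;

    fn feedforward(&self, input: &RowVector<I>) -> Self::LayerOutputs;
    fn backpropagate(
        &mut self,
        input: &RowVector<I>,
        output: &RowVector<O>,
        layer_outputs: Self::LayerOutputs,
    ) -> f32;

    /// `Borrow` lets callers pass owned arrays or references.
    fn train_once<BI, BO>(&mut self, input: BI, output: BO) -> f32
    where
        BI: Borrow<RowVector<I>>,
        BO: Borrow<RowVector<O>>,
    {
        let input: &RowVector<I> = input.borrow();
        let output: &RowVector<O> = output.borrow();
        let layer_outputs = self.feedforward(input); // immutable borrow ends here
        self.backpropagate(input, output, layer_outputs)
    }
}

/// Hypothetical toy model: y = x, loss = squared error, no parameters to update.
struct Identity;

impl Trainable<1, 1> for Identity {
    type LayerOutputs = RowVector<1>;

    fn feedforward(&self, input: &RowVector<1>) -> RowVector<1> {
        *input
    }

    fn backpropagate(&mut self, _input: &RowVector<1>, output: &RowVector<1>, y: RowVector<1>) -> f32 {
        (y[0] - output[0]).powi(2)
    }
}
```

For example, both Identity.train_once([3.0f32], [1.0f32]) and Identity.train_once(&x, &t) with x = [3.0] and t = [1.0] return 4.0.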

fn train<D, T>(
    &mut self,
    data: D,
    callback: Option<impl FnMut(usize, f32)>,
) -> (usize, f32)
where D: IntoIterator<Item = T>, T: Borrow<(RowVector<I>, RowVector<O>)>,

Train the network on each (input, target output) pair in data, calling the optional callback after each sample. Returns a tuple of the total number of samples processed and the final average loss.
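A sketch of how a default train could drive train_once over the data and keep a running average. To stay short, it uses a hypothetical reduced trait in which train_once is the only required method and takes plain references. Two details are assumptions rather than documented behavior: the callback is taken to receive the 1-based sample count and the running average loss, and an empty data set is taken to report an average of 0.0. Because callback is an Option<impl FnMut(usize, f32)>, passing None needs an explicit type such as None::<fn(usize, f32)>.

```rust
use std::borrow::Borrow;

/// Hypothetical stand-in for the crate's 1×N `RowVector<N>`.
type RowVector<const N: usize> = [f32; N];

/// Hypothetical reduced trait: `train_once` is required, `train` is provided.
trait Trainable<const I: usize, const O: usize> {
    fn train_once(&mut self, input: &RowVector<I>, output: &RowVector<O>) -> f32;

    fn train<D, T>(&mut self, data: D, mut callback: Option<impl FnMut(usize, f32)>) -> (usize, f32)
    where
        D: IntoIterator<Item = T>,
        T: Borrow<(RowVector<I>, RowVector<O>)>,
    {
        let (mut count, mut total_loss) = (0usize, 0.0f32);
        for sample in data {
            let pair: &(RowVector<I>, RowVector<O>) = sample.borrow();
            total_loss += self.train_once(&pair.0, &pair.1);
            count += 1;
            // Assumed callback arguments: (samples so far, running average loss).
            if let Some(cb) = callback.as_mut() {
                cb(count, total_loss / count as f32);
            }
        }
        // Assumption: an empty data set reports an average loss of 0.0.
        let average = if count == 0 { 0.0 } else { total_loss / count as f32 };
        (count, average)
    }
}

/// Hypothetical toy model whose "loss" is simply the first input value.
struct FixedLoss;

impl Trainable<1, 1> for FixedLoss {
    fn train_once(&mut self, input: &RowVector<1>, _output: &RowVector<1>) -> f32 {
        input[0]
    }
}
```

Taking T: Borrow<(…)> means the same method accepts both an owned Vec of pairs and &Vec, so callers can reuse a training set across epochs without cloning it.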

fn predict(&self, input: &RowVector<I>) -> RowVector<O>

Predict the output for a given input.
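predict can plausibly be provided in terms of feedforward alone: run the forward pass and keep only the final layer's output. The sketch assumes the LayerOutputs<O> trait exposes that output through a hypothetical output() method; the real trait's methods are not shown on this page. Predict, Outputs and SumNet are likewise hypothetical.

```rust
/// Hypothetical stand-in for the crate's 1×N `RowVector<N>`.
type RowVector<const N: usize> = [f32; N];

/// Hypothetical reduced `LayerOutputs` trait exposing the final output.
trait LayerOutputs<const O: usize> {
    fn output(&self) -> &RowVector<O>;
}

/// Hypothetical reduced trait: `feedforward` is required, `predict` is provided.
trait Predict<const I: usize, const O: usize> {
    type LayerOutputs: LayerOutputs<O>;

    fn feedforward(&self, input: &RowVector<I>) -> Self::LayerOutputs;

    /// Forward pass, then copy out the final layer's activations.
    fn predict(&self, input: &RowVector<I>) -> RowVector<O> {
        *self.feedforward(input).output()
    }
}

/// Hypothetical per-layer record: hidden activations plus the final output.
struct Outputs {
    hidden: RowVector<2>,
    output: RowVector<1>,
}

impl LayerOutputs<1> for Outputs {
    fn output(&self) -> &RowVector<1> {
        &self.output
    }
}

/// Hypothetical toy network: hidden = x, output = sum of the hidden values.
struct SumNet;

impl Predict<2, 1> for SumNet {
    type LayerOutputs = Outputs;

    fn feedforward(&self, input: &RowVector<2>) -> Outputs {
        Outputs { hidden: *input, output: [input[0] + input[1]] }
    }
}
```

The intermediate layer outputs are discarded, which is why predict can take &self and needs no target output.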

Dyn Compatibility

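This trait is not dyn compatible: its provided methods train_once and train have generic type parameters, so they cannot be called through a trait object.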

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
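In practice, code that should work with any network has to use generics (static dispatch) rather than a trait object such as Box<dyn NeuralNetwork<I, O>>. The sketch below illustrates this on a small hypothetical trait, Predictor. It also shows the standard Rust technique of adding a where Self: Sized bound to a generic method, which excludes that method from trait objects and keeps the rest of the trait dyn compatible; whether that trade-off suits this trait is a design choice. A trait object of NeuralNetwork would additionally have to name its associated type, as in dyn NeuralNetwork<I, O, LayerOutputs = …>.

```rust
use std::borrow::Borrow;

/// Hypothetical stand-in for the crate's 1×N `RowVector<N>`.
type RowVector<const N: usize> = [f32; N];

/// Small hypothetical trait with one generic provided method.
trait Predictor<const I: usize, const O: usize> {
    fn predict(&self, input: &RowVector<I>) -> RowVector<O>;

    /// Generic method: without `Self: Sized`, it alone would make
    /// `dyn Predictor<I, O>` impossible.
    fn predict_from<B: Borrow<RowVector<I>>>(&self, input: B) -> RowVector<O>
    where
        Self: Sized,
    {
        let input: &RowVector<I> = input.borrow();
        self.predict(input)
    }
}

/// Hypothetical toy model that doubles every input.
struct Doubler;

impl Predictor<2, 2> for Doubler {
    fn predict(&self, input: &RowVector<2>) -> RowVector<2> {
        input.map(|x| 2.0 * x)
    }
}

/// Static dispatch: works for any implementor, dyn compatible or not.
fn predict_static<N: Predictor<2, 2>>(net: &N, input: &RowVector<2>) -> RowVector<2> {
    net.predict(input)
}

/// Dynamic dispatch: compiles only because `predict_from` requires `Self: Sized`.
fn predict_dynamic(net: &dyn Predictor<2, 2>, input: &RowVector<2>) -> RowVector<2> {
    net.predict(input)
}
```

Generic methods remain callable on concrete types (Doubler.predict_from([1.0, 2.0]) works), but not through &dyn Predictor<2, 2>.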

Implementors

impl<const I: usize, const H: usize, const O: usize> NeuralNetwork<I, O> for SimpleNeuralNetwork<I, H, O>