Crate fast_neural_network
source ·Expand description
Fast Neural Network Library
This library is a simple neural network library written in Rust. It is designed to be fast and easy to use. It supports saving and loading networks to and from JSON files. All of the heavy operations are parallelized. Stack-based networks are yet to be implemented.
Example
use fast_neural_network::{activation::*, neural_network::*};
use ndarray::*;
fn main() {
let mut network = Network::new(2, 1, ActivationType::Tanh, 0.005); // Create a new network with 2 inputs, 1 output, a tanh activation function, and a learning rate of 0.005
network.add_hidden_layer_with_size(2); // Add a hidden layer with 2 neurons
network.compile(); // Compile the network to prepare it for training
// (will be done automatically during training)
// The API is exposed so that the user can compile
// the network on a different thread before training if they want to
// Let's create a dataset
let mut dataset: Vec<(ndarray::Array1<f64>, ndarray::Array1<f64>)> = Vec::new();
dataset.push((array!(0., 0.), array!(0.)));
dataset.push((array!(1., 0.), array!(1.)));
dataset.push((array!(0., 1.), array!(1.)));
dataset.push((array!(1., 1.), array!(0.)));
let untrained_prediction = network.forward(&input); // Predict the output of the network
println!("{:?}", trained_prediction);
network.train(&dataset, 50_000, 5_000); // train the network for 50,000 epochs with a decay_time of 5,000 epochs
let new_prediction = network.forward(&input); // Predict the output of the network again
let mut res;
// Let's check the result
for i in 0..dataset.len() {
res = network.forward(&dataset[i].0);
let d = &dataset[i];
println!(
"for [{:.3}, {:.3}], [{:.3}] -> [{:.3}]",
d.0[0], d.0[1], d.1[0], res
);
}
network.save("network.json"); // Save the model as a json to a file
let mut loaded_network = Network::load("network.json"); // Load the model from a json file
println!("{:?}", loaded_network.predict(&input));
}
Modules
- `activation`: Activation functions and their derivatives.
- `neural_network`: Neural Network