Rust-Neural-Network
This is the beginning of a neural network library written in Rust, designed to provide a flexible and efficient platform for building and training neural networks. The current implementation requires everything to be allocated on the heap and partially computed. This will change once Rust has better support for generics in constant expressions.
The nightly compiler is not something I want to bother with.
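To illustrate the limitation described above, here is a hypothetical sketch of a stack-allocated matrix using const generics. It is not the library's Matrix type, and the names StackMatrix, from_array, and mul are made up for illustration. Stable Rust can already check matrix-multiplication shapes at compile time. What it cannot yet express is a dimension computed from other parameters, such as flattening into a StackMatrix<{ R * C }, 1>; that needs the nightly-only generic_const_exprs feature.

```rust
// Hypothetical stack-allocated matrix (not the library's Matrix type).
// The data lives in a fixed-size array, so no heap allocation is needed.
#[derive(Debug, Clone, Copy, PartialEq)]
struct StackMatrix<const R: usize, const C: usize> {
    data: [[f64; C]; R],
}

impl<const R: usize, const C: usize> StackMatrix<R, C> {
    fn from_array(data: [[f64; C]; R]) -> Self {
        Self { data }
    }

    // Shapes are checked at compile time: (R x C) * (C x K) -> (R x K).
    fn mul<const K: usize>(&self, other: &StackMatrix<C, K>) -> StackMatrix<R, K> {
        let mut out = [[0.0; K]; R];
        for i in 0..R {
            for j in 0..K {
                let mut sum = 0.0;
                for k in 0..C {
                    sum += self.data[i][k] * other.data[k][j];
                }
                out[i][j] = sum;
            }
        }
        StackMatrix { data: out }
    }
}

fn main() {
    let a = StackMatrix::from_array([[1.0, 2.0], [3.0, 4.0]]); // 2 x 2
    let b = StackMatrix::from_array([[5.0], [6.0]]); // 2 x 1
    let c = a.mul(&b); // 2 x 1
    println!("{:?}", c.data); // prints [[17.0], [39.0]]
    // b.mul(&b) would not compile: (2 x 1) * (2 x 1) has mismatched inner dimensions.
}
```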
Project Status
Current Version: 0.2.0 (beta)
This library is still in early development. The current version is a beta, and it will jump to 1.0.0 once stack-based allocation is implemented.
Contributions and feedback are welcome, but please be aware that the internal structure may change significantly as the library matures. In particular, don't depend on the internal Matrix implementation, as it will most likely change.
Features
- Basic Neural Network Layers: The library currently supports fundamental neural network layers such as fully connected (dense) layers, convolutional layers (I call them hidden layers in the API), and activation functions (ReLU, sigmoid, etc.).
- Backpropagation: The library includes an implementation of backpropagation, which is crucial for training neural networks. This allows the network to learn from data and update its weights accordingly.
- Model Serialization: I plan to support model serialization so that users can save and load trained models easily. In the meantime, I have started testing serde.
- Documentation: I'll write the documentation once I'm ready to publish to crates.io. For now, the example in the main.rs file should be more than enough.
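As a rough illustration of the activation functions and backpropagation mentioned in the feature list, here is a minimal sketch for a single neuron with a sigmoid activation and a squared-error loss. It is not the library's implementation: the Neuron struct, its methods, and the learning-rate parameter are hypothetical, and the library may use a different loss or update rule.

```rust
// Minimal sketch: activation functions and one backpropagation step for a
// single hypothetical neuron. Not the library's implementation.

fn relu(x: f64) -> f64 {
    x.max(0.0)
}

fn relu_derivative(x: f64) -> f64 {
    if x > 0.0 { 1.0 } else { 0.0 }
}

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn sigmoid_derivative(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 - s)
}

/// Hypothetical neuron computing sigmoid(w · x + b).
struct Neuron {
    weights: Vec<f64>,
    bias: f64,
}

impl Neuron {
    /// Weighted sum plus bias, before the activation is applied.
    fn pre_activation(&self, input: &[f64]) -> f64 {
        self.weights.iter().zip(input).map(|(w, x)| w * x).sum::<f64>() + self.bias
    }

    fn forward(&self, input: &[f64]) -> f64 {
        sigmoid(self.pre_activation(input))
    }

    /// One gradient-descent step on the loss 0.5 * (output - target)^2.
    /// Returns the loss measured before the update.
    fn back_propagate(&mut self, input: &[f64], target: f64, learning_rate: f64) -> f64 {
        let z = self.pre_activation(input);
        let error = sigmoid(z) - target;
        // Chain rule: dLoss/dz = (output - target) * sigmoid'(z).
        let delta = error * sigmoid_derivative(z);
        for (w, x) in self.weights.iter_mut().zip(input) {
            *w -= learning_rate * delta * x; // dLoss/dw_i = delta * x_i
        }
        self.bias -= learning_rate * delta; // dLoss/db = delta
        0.5 * error * error
    }
}

fn main() {
    println!("relu(-2) = {}, relu'(3) = {}", relu(-2.0), relu_derivative(3.0));
    let mut neuron = Neuron { weights: vec![0.5, -0.25], bias: 0.1 };
    let input = [1.0, 2.0];
    let before = neuron.forward(&input);
    let loss = neuron.back_propagate(&input, 1.0, 0.5);
    let after = neuron.forward(&input);
    println!("loss: {loss:.4}, output before: {before:.4}, after: {after:.4}");
}
```

After one step the output moves toward the target, which is the behavior backpropagation is meant to produce.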
Here's an example of creating a simple feedforward neural network with the library, for those who don't have time to browse the file:
use *;
use *;
use *;
let mut network = empty_network;
network.add_hidden_layer_with_size;
network.add_hidden_layer_with_size;
network.compile; // Compile the network to prepare it for training (will be done automatically during training)
// The API is exposed so that the user can compile the network on a different thread before training if they want to
// setting up the weights and biases of the network manually
let layer_1_weights = from_vec;
let layer_1_biases = from_vec;
let layer_2_weights = from_vec;
let layer_2_biases = from_vec;
let layer_3_weights = from_vec;
let layer_3_biases = from_vec;
network.set_layer_weights;
network.set_layer_biases;
network.set_layer_weights;
network.set_layer_biases;
network.set_layer_weights;
network.set_layer_biases;
// defining the input for the iteration
let input: = vec!;
let prediction = network.forward_propagate; // Predict the output of the network
let error = network.back_propagate; // Backpropagate the input with a target output of 9.0
let new_prediction = network.forward_propagate; // Predict the output of the network again
println!;
println!;
network.save; // Save the model to a JSON file
let mut network = load; // Load the model from a json file
println!;
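The example above stacks several layers, each with its own weights and biases. The following is a minimal, self-contained sketch of what a forward pass through such layers can look like. DenseLayer, forward_propagate, the row-per-neuron weight layout, and the choice of ReLU for hidden layers with an identity output are all assumptions for illustration, not the library's actual types or API.

```rust
// Hypothetical dense layers and forward pass; names are illustrative only.

/// weights[i][j] connects input j to output neuron i.
struct DenseLayer {
    weights: Vec<Vec<f64>>,
    biases: Vec<f64>,
}

impl DenseLayer {
    /// Computes activation(W · input + b) for every neuron in the layer.
    fn forward(&self, input: &[f64], activation: fn(f64) -> f64) -> Vec<f64> {
        self.weights
            .iter()
            .zip(&self.biases)
            .map(|(row, bias)| {
                assert_eq!(row.len(), input.len(), "input size mismatch");
                let z = row.iter().zip(input).map(|(w, x)| w * x).sum::<f64>() + bias;
                activation(z)
            })
            .collect()
    }
}

fn relu(x: f64) -> f64 {
    x.max(0.0)
}

fn identity(x: f64) -> f64 {
    x
}

/// Feeds the input through every layer in order and returns the final output.
fn forward_propagate(layers: &[DenseLayer], input: &[f64]) -> Vec<f64> {
    let mut activations = input.to_vec();
    for (i, layer) in layers.iter().enumerate() {
        // Assumption: ReLU on hidden layers, identity on the output layer.
        let activation: fn(f64) -> f64 = if i + 1 == layers.len() { identity } else { relu };
        activations = layer.forward(&activations, activation);
    }
    activations
}

fn main() {
    // 2 inputs -> 2 hidden neurons -> 1 output.
    let layers = vec![
        DenseLayer { weights: vec![vec![1.0, -1.0], vec![0.5, 0.5]], biases: vec![0.0, 1.0] },
        DenseLayer { weights: vec![vec![2.0, 1.0]], biases: vec![-0.5] },
    ];
    let output = forward_propagate(&layers, &[2.0, 1.0]);
    println!("{output:?}"); // prints [4.0]
}
```

Passing the activation as a plain fn pointer keeps each layer simple; a real library might instead store an activation enum per layer.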
Speed
The focus of this library is multi-threaded performance. I have done my best to optimize the code, but the library is still in its early stages, so there is still room for improvement. I just wish Rust had better support for generics in constant expressions, like C++ does.
Matrix parallelization is currently not implemented, but it will be once better generics are implemented in Rust.
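Since matrix parallelization is not implemented yet, here is one possible approach, sketched with only the standard library: split the rows of a matrix-vector product across scoped threads, giving each thread a disjoint slice of the output. The function par_mat_vec and the flat row-major layout are hypothetical and unrelated to the library's internal Matrix.

```rust
use std::thread;

/// Multiplies a row-major `rows x cols` matrix by `vector`, splitting the rows
/// into contiguous chunks that are processed on separate scoped threads.
fn par_mat_vec(matrix: &[f64], cols: usize, vector: &[f64], threads: usize) -> Vec<f64> {
    assert!(cols > 0 && threads > 0);
    assert_eq!(vector.len(), cols, "vector length must equal column count");
    assert_eq!(matrix.len() % cols, 0, "matrix length must be a multiple of cols");

    let rows = matrix.len() / cols;
    let mut output = vec![0.0_f64; rows];
    // Ceiling division so every row is assigned to some thread.
    let rows_per_thread = ((rows + threads - 1) / threads).max(1);

    thread::scope(|s| {
        // Each thread owns a disjoint slice of `output`, so no locking is needed.
        for (chunk_index, out_chunk) in output.chunks_mut(rows_per_thread).enumerate() {
            let first_row = chunk_index * rows_per_thread;
            s.spawn(move || {
                for (offset, out) in out_chunk.iter_mut().enumerate() {
                    let start = (first_row + offset) * cols;
                    let row = &matrix[start..start + cols];
                    *out = row.iter().zip(vector).map(|(a, b)| a * b).sum();
                }
            });
        }
    }); // All spawned threads are joined here before `scope` returns.
    output
}

fn main() {
    // 3 x 2 matrix [[1, 2], [3, 4], [5, 6]] times [1, 1].
    let matrix = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let result = par_mat_vec(&matrix, 2, &[1.0, 1.0], 2);
    println!("{result:?}"); // prints [3.0, 7.0, 11.0]
}
```

Scoped threads (std::thread::scope, stable since Rust 1.63) let the workers borrow the matrix and output directly instead of cloning them into Arc-wrapped buffers.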
Contributing
Contributions are highly encouraged! If you're interested in adding new features, improving performance, fixing bugs, or enhancing documentation, I would appreciate your help. Just open a pull request and I'll look into it.
Roadmap
The following features might be implemented in future releases:
- Support for more activation functions
- GPU acceleration using CUDA or similar technologies (probably just shaders but idk it seems hard)
- Enhanced model evaluation tools (and maybe a GUI to go with them; if I write one, it will be in raylib btw)