# Hextral
A high-performance neural network library for Rust, featuring comprehensive support for multi-layer perceptrons with advanced optimization techniques, regularization methods, and flexible architecture design.
## Features

### Core Architecture
- Multi-layer perceptrons with configurable hidden layers
- Batch processing for efficient training on large datasets
- Xavier weight initialization for stable gradient flow
- Flexible network topology - specify any number of hidden layers and neurons
### Activation Functions
- ReLU - Rectified Linear Unit (default, good for most cases)
- Sigmoid - Smooth activation for binary classification
- Tanh - Hyperbolic tangent for centered outputs
- Leaky ReLU - Prevents dying ReLU problem
- ELU - Exponential Linear Unit for smoother gradients
- Linear - For regression output layers
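Selecting a variant is a one-liner. A minimal sketch; the exact enum spellings (e.g. `LeakyReLU`) are assumed from the list above:

```rust
use hextral::ActivationFunction;

// ReLU for hidden layers, Linear for a regression output layer.
// Variant spellings are assumed from the list above.
let hidden = ActivationFunction::ReLU;
let output = ActivationFunction::Linear;
```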
### Optimization Algorithms
- Adam - Adaptive moment estimation (recommended for most cases)
- SGD - Stochastic Gradient Descent (simple and reliable)
- SGD with Momentum - Accelerated gradient descent
### Regularization Techniques
- L2 Regularization - Prevents overfitting by penalizing large weights
- L1 Regularization - Encourages sparse networks and feature selection
- Dropout - Randomly deactivates neurons during training to prevent co-adaptation
### Training & Evaluation
- Training progress tracking with loss history
- Batch and single-sample prediction
- Model evaluation metrics and loss computation
- Architecture introspection - query network structure and parameter count
## Quick Start

Add this to your `Cargo.toml`:

```toml
[dependencies]
hextral = "0.4.0"
nalgebra = "0.33"
```
### Basic Usage

```rust
use hextral::{ActivationFunction, Hextral, Optimizer};
use nalgebra::DVector;
```
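Building on those imports, here is a minimal end-to-end sketch. The constructor and method signatures (layer-size arguments, optimizer fields, the epoch argument to `train()`) are assumptions for illustration, not the verified v0.4.0 API:

```rust
use hextral::{ActivationFunction, Hextral, Optimizer};
use nalgebra::DVector;

fn main() {
    // Build a 2 -> 4 -> 3 -> 1 network; constructor arguments are illustrative.
    let mut nn = Hextral::new(
        2,        // input size
        &[4, 3],  // hidden layers
        1,        // output size
        ActivationFunction::ReLU,
        Optimizer::Adam { learning_rate: 0.001, beta1: 0.9, beta2: 0.999 },
    );

    // Toy XOR dataset.
    let inputs: Vec<DVector<f64>> = vec![
        DVector::from_vec(vec![0.0, 0.0]),
        DVector::from_vec(vec![0.0, 1.0]),
        DVector::from_vec(vec![1.0, 0.0]),
        DVector::from_vec(vec![1.0, 1.0]),
    ];
    let targets: Vec<DVector<f64>> = vec![
        DVector::from_vec(vec![0.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![0.0]),
    ];

    // train() returns the per-epoch loss history (assumed here).
    let loss_history = nn.train(&inputs, &targets, 1_000);
    println!("final loss: {:?}", loss_history.last());

    // Single-sample prediction.
    let prediction = nn.predict(&inputs[2]);
    println!("prediction for [1, 0]: {}", prediction);
}
```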
## Regularization

Prevent overfitting with built-in regularization techniques:

```rust
use hextral::{Hextral, Regularization};

let mut nn = Hextral::new(/* ... */);

// Variant shapes and coefficient values below are illustrative.

// L2 regularization (Ridge)
nn.set_regularization(Regularization::L2(0.01));

// L1 regularization (Lasso)
nn.set_regularization(Regularization::L1(0.01));

// Dropout regularization (probability of dropping a neuron)
nn.set_regularization(Regularization::Dropout(0.5));
```
## Different Optimizers

Choose the optimizer that works best for your problem:

```rust
use hextral::Optimizer;

// Field names and values below are illustrative; check the crate
// docs for the exact variant definitions.

// Adam: good default choice, adaptive learning rates
let optimizer = Optimizer::Adam { learning_rate: 0.001, beta1: 0.9, beta2: 0.999 };

// SGD: simple and interpretable
let optimizer = Optimizer::SGD { learning_rate: 0.01 };

// SGD with Momentum: accelerated convergence
let optimizer = Optimizer::SGDMomentum { learning_rate: 0.01, momentum: 0.9 };
```
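For context, classical momentum maintains a running velocity `v ← μ·v − η·∇L` and applies `w ← w + v`, which damps oscillation across noisy gradients and accelerates progress along consistently downhill directions. Whether the crate uses exactly this formulation is an implementation detail; the intuition carries over.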
## Network Introspection

Get insights into your network:

```rust
// Introspection method names below are assumed; see the crate docs
// for the exact API.

// Network architecture
println!("{:?}", nn.architecture()); // [2, 4, 3, 1]

// Parameter count
println!("{}", nn.parameter_count()); // 25

// Save/load weights
let weights = nn.get_weights();
nn.set_weights(weights);
```
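Together, `get_weights()` and `set_weights()` let you snapshot a trained model and restore it later, for example by persisting the returned weights with a serialization crate of your choice.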
## API Reference
### Core Types
- **`Hextral`** - Main neural network struct
- **`ActivationFunction`** - Enum for activation functions
- **`Optimizer`** - Enum for optimization algorithms
- **`Regularization`** - Enum for regularization techniques
### Key Methods
- **`new()`** - Create a new neural network
- **`train()`** - Train the network for multiple epochs
- **`predict()`** - Make a single prediction
- **`evaluate()`** - Compute loss on a dataset
- **`set_regularization()`** - Configure regularization
## Performance Tips
1. **Use ReLU activation** for hidden layers in most cases
2. **Start with Adam optimizer** - it adapts learning rates automatically
3. **Apply L2 regularization** if you see overfitting (test loss > train loss)
4. **Use dropout for large networks** to prevent co-adaptation
5. **Normalize your input data** to the [0,1] or [-1,1] range for better training stability (see the sketch after this list)
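A minimal min-max scaling helper built on nalgebra; `normalize` is a hypothetical utility written for this README, not part of Hextral's API:

```rust
use nalgebra::DVector;

// Scale a vector's components into [0, 1]. In practice you would
// compute min/max per feature over the whole training set rather
// than per sample.
fn normalize(v: &DVector<f64>) -> DVector<f64> {
    let min = v.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = v.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max == min {
        return DVector::zeros(v.len());
    }
    v.map(|x| (x - min) / (max - min))
}
```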
## Architecture Decisions
- **Built on nalgebra** for efficient linear algebra operations
- **Xavier initialization** for stable gradient flow from the start
- **Proper error handling** throughout the API
- **Modular design** allowing easy extension of activation functions and optimizers
- **Zero-copy predictions** where possible for performance
## Contributing
We welcome contributions! Please feel free to:
- Report bugs by opening an issue
- Suggest new features or improvements
- Submit pull requests with enhancements
- Improve documentation
- Add more test cases
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Changelog
### v0.4.0 (Latest)
- **Complete rewrite** with proper error handling and fixed implementations
- **Implemented all documented features** - train(), predict(), evaluate() methods
- **Fixed critical bugs** in batch normalization and backward pass
- **Added regularization support** - L1, L2, and Dropout
- **Improved documentation** with usage examples and API reference
### v0.3.5 (Previous)
- Basic neural network functionality
- Limited optimizer support
- Minimal documentation