# Hextral
A comprehensive neural network library for Rust with modern features including batch normalization, multiple loss functions, advanced activation functions, and flexible architecture design.
## Features
### Core Architecture
- Multi-layer perceptrons with configurable hidden layers
- Batch normalization for improved training stability and convergence
- Xavier weight initialization for stable gradient flow
- Flexible network topology - specify any number of hidden layers and neurons
### Activation Functions (9 Available)
- ReLU - Rectified Linear Unit (good for most cases)
- Sigmoid - Smooth activation for binary classification
- Tanh - Hyperbolic tangent for centered outputs
- Leaky ReLU - Prevents dying ReLU problem
- ELU - Exponential Linear Unit for smoother gradients
- Linear - For regression output layers
- Swish - Modern activation with smooth derivatives
- GELU - Gaussian Error Linear Unit used in transformers
- Mish - Self-regularizing activation function
### Loss Functions (5 Available)
- Mean Squared Error (MSE) - Standard regression loss
- Mean Absolute Error (MAE) - Robust to outliers
- Binary Cross-Entropy - Binary classification
- Categorical Cross-Entropy - Multi-class classification
- Huber Loss - Robust hybrid of MSE and MAE
### Optimization Algorithms
- Adam - Adaptive moment estimation (recommended for most cases)
- SGD - Stochastic Gradient Descent (simple and reliable)
- SGD with Momentum - Accelerated gradient descent
### Regularization Techniques
- L2 Regularization - Prevents overfitting by penalizing large weights
- L1 Regularization - Encourages sparse networks and feature selection
- Dropout - Randomly deactivates neurons during training
### Training & Evaluation
- Flexible loss computation with configurable loss functions
- Batch normalization with training/inference modes
- Training progress tracking with loss history
- Batch and single-sample prediction
- Model evaluation metrics and loss computation
## Quick Start
Add this to your `Cargo.toml`:

```toml
[dependencies]
hextral = "0.5.1"
nalgebra = "0.33"
```
### Basic Usage

The examples in this README use the `Hextral` network type together with `nalgebra` vectors:

```rust
use hextral::Hextral;
use nalgebra::DVector;
```
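Below is a minimal end-to-end sketch of the typical workflow. The `Hextral::new` argument list and the `train`/`evaluate`/`predict` parameters are assumptions for illustration; consult the crate documentation for the exact signatures.

```rust
use hextral::Hextral;
use nalgebra::DVector;

fn main() {
    // The constructor arguments (layer sizes, activation, optimizer) are assumed here;
    // check the crate docs for the exact `Hextral::new` signature.
    let mut nn = Hextral::new(/* input size, hidden layers, output size, activation, optimizer */);

    // A tiny XOR-style dataset as nalgebra vectors.
    let inputs: Vec<DVector<f64>> = vec![
        DVector::from_vec(vec![0.0, 0.0]),
        DVector::from_vec(vec![0.0, 1.0]),
        DVector::from_vec(vec![1.0, 0.0]),
        DVector::from_vec(vec![1.0, 1.0]),
    ];
    let targets: Vec<DVector<f64>> = vec![
        DVector::from_vec(vec![0.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![0.0]),
    ];

    // Train, then evaluate the loss and run a single prediction
    // (parameter lists are assumptions as well).
    let loss_history = nn.train(&inputs, &targets, /* learning rate, epochs */);
    let loss = nn.evaluate(&inputs, &targets);
    let output = nn.predict(&inputs[0]);

    println!("final loss: {loss:?}");
    println!("prediction for [0, 0]: {output:?}");
    println!("recorded {} loss values", loss_history.len());
}
```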
## Loss Functions
Configure different loss functions for your specific task:
```rust
use hextral::{Hextral, LossFunction};

// The `LossFunction` enum and variant names below are indicative;
// see the crate docs for the exact identifiers.
let mut nn = Hextral::new(/* layer sizes, activation, optimizer */);

// For regression tasks
nn.set_loss_function(LossFunction::MeanSquaredError);
nn.set_loss_function(LossFunction::MeanAbsoluteError);
nn.set_loss_function(LossFunction::Huber);

// For classification tasks
nn.set_loss_function(LossFunction::BinaryCrossEntropy);
nn.set_loss_function(LossFunction::CategoricalCrossEntropy);
```
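For intuition on why Huber loss is described as a hybrid of MSE and MAE, here is a standalone sketch independent of Hextral's API, with the conventional `delta` threshold separating the two regimes:

```rust
/// Standalone Huber loss: quadratic (MSE-like) for residuals within `delta`,
/// linear (MAE-like) beyond it, which damps the influence of outliers.
fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> f64 {
    pred.iter()
        .zip(target)
        .map(|(p, t)| {
            let e = (p - t).abs();
            if e <= delta {
                0.5 * e * e // small residuals: squared-error regime
            } else {
                delta * (e - 0.5 * delta) // large residuals: absolute-error regime
            }
        })
        .sum::<f64>()
        / pred.len() as f64
}

fn main() {
    let pred = [1.0, 2.0, 10.0];
    let target = [1.5, 2.0, 2.0]; // the last point is an outlier
    println!("huber = {:.3}", huber_loss(&pred, &target, 1.0));
}
```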
## Batch Normalization
Enable batch normalization for improved training stability:
```rust
use hextral::Hextral;

// Argument lists below are indicative; see the crate docs for the exact signatures.
let mut nn = Hextral::new(/* layer sizes, activation, optimizer */);

// Enable batch normalization
nn.enable_batch_norm();

// Set training mode
nn.set_training_mode(true);

// Train your network...
let loss_history = nn.train(&inputs, &targets, /* learning rate, epochs */);

// Switch to inference mode
nn.set_training_mode(false);

// Make predictions...
let prediction = nn.predict(&input);
```
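For reference, the per-feature normalization that batch norm applies during training can be sketched independently of the crate as follows (in inference mode, batch normalization conventionally uses running statistics gathered during training rather than per-batch statistics):

```rust
use nalgebra::DVector;

/// Normalize one feature across a batch: subtract the batch mean, divide by the
/// batch standard deviation, then apply the learnable scale (gamma) and shift (beta).
fn batch_norm(feature: &DVector<f64>, gamma: f64, beta: f64, eps: f64) -> DVector<f64> {
    let mean = feature.mean();
    let var = feature.map(|x| (x - mean).powi(2)).mean();
    feature.map(|x| gamma * (x - mean) / (var + eps).sqrt() + beta)
}

fn main() {
    let feature = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let normalized = batch_norm(&feature, 1.0, 0.0, 1e-5);
    println!("{}", normalized);
}
```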
## Modern Activation Functions
Use state-of-the-art activation functions:
```rust
use hextral::{ActivationFunction, Hextral};

// Constructor arguments are indicative; see the crate docs for the exact `new()` signature.

// Swish activation (used in EfficientNet)
let mut nn = Hextral::new(/* layers */, ActivationFunction::Swish, /* optimizer */);

// GELU activation (used in BERT, GPT)
let mut nn = Hextral::new(/* layers */, ActivationFunction::GELU, /* optimizer */);

// Mish activation (self-regularizing)
let mut nn = Hextral::new(/* layers */, ActivationFunction::Mish, /* optimizer */);
```
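For reference, standalone scalar definitions of these three activations (GELU shown with its common tanh approximation); these sketches are independent of Hextral's internal implementation:

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Swish: x * sigmoid(x), smooth and non-monotonic.
fn swish(x: f64) -> f64 {
    x * sigmoid(x)
}

/// GELU via the widely used tanh approximation of x * Phi(x),
/// where Phi is the standard normal CDF.
fn gelu(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

/// Mish: x * tanh(softplus(x)), self-regularizing and smooth.
fn mish(x: f64) -> f64 {
    x * x.exp().ln_1p().tanh() // softplus(x) = ln(1 + e^x)
}

fn main() {
    for x in [-2.0, 0.0, 2.0] {
        println!("x = {x:>4}: swish = {:.3}, gelu = {:.3}, mish = {:.3}", swish(x), gelu(x), mish(x));
    }
}
```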
## Regularization
Prevent overfitting with built-in regularization techniques:
```rust
use hextral::{Hextral, Regularization};

// Variant shapes and the strength/rate values below are indicative; see the crate docs.
let mut nn = Hextral::new(/* layer sizes, activation, optimizer */);

// L2 regularization (Ridge)
nn.set_regularization(Regularization::L2(0.01));

// L1 regularization (Lasso)
nn.set_regularization(Regularization::L1(0.01));

// Dropout regularization
nn.set_regularization(Regularization::Dropout(0.5));
```
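For intuition, the penalty terms that L1 and L2 regularization add to the training loss can be sketched independently of the crate (`lambda` below is the regularization strength):

```rust
/// L2 (Ridge) penalty: lambda * sum of squared weights, discouraging large weights.
fn l2_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|w| w * w).sum::<f64>()
}

/// L1 (Lasso) penalty: lambda * sum of absolute weights, pushing weights toward exact zero.
fn l1_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|w| w.abs()).sum::<f64>()
}

fn main() {
    let weights = [0.5, -1.2, 0.0, 2.0];
    let data_loss = 0.42; // whatever the unregularized loss happens to be
    println!("with L2: {:.4}", data_loss + l2_penalty(&weights, 0.01));
    println!("with L1: {:.4}", data_loss + l1_penalty(&weights, 0.01));
}
```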
## Different Optimizers
Choose the optimizer that works best for your problem:
```rust
use hextral::Optimizer;

// Variant fields below are indicative; see the crate docs for the exact definitions.

// Adam: Good default choice, adaptive learning rates
let optimizer = Optimizer::Adam { /* learning rate, beta1, beta2 */ };

// SGD: Simple and interpretable
let optimizer = Optimizer::SGD { /* learning rate */ };

// SGD with Momentum: Accelerated convergence
let optimizer = Optimizer::SGDMomentum { /* learning rate, momentum */ };
```
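To make the difference concrete, here is a standalone sketch of the plain SGD and momentum update rules for a single parameter (values are illustrative; Adam extends this idea with per-parameter estimates of the gradient's first and second moments):

```rust
/// Plain SGD step: move against the gradient, scaled by the learning rate.
fn sgd_step(weight: &mut f64, grad: f64, lr: f64) {
    *weight -= lr * grad;
}

/// SGD with momentum: accumulate a velocity term so consistent gradients accelerate,
/// while oscillating gradients partially cancel out.
fn sgd_momentum_step(weight: &mut f64, velocity: &mut f64, grad: f64, lr: f64, momentum: f64) {
    *velocity = momentum * *velocity - lr * grad;
    *weight += *velocity;
}

fn main() {
    let (mut w_plain, mut w_momentum, mut velocity) = (1.0_f64, 1.0_f64, 0.0_f64);
    for _ in 0..3 {
        let grad = 0.5; // pretend the loss gradient is constant
        sgd_step(&mut w_plain, grad, 0.1);
        sgd_momentum_step(&mut w_momentum, &mut velocity, grad, 0.1, 0.9);
    }
    println!("plain SGD: {w_plain:.3}, momentum SGD: {w_momentum:.3}");
}
```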
## Network Introspection
Get insights into your network (the architecture and parameter-count accessors shown here are placeholders; check the crate docs for the exact method names):

```rust
// Network architecture (accessor name is a placeholder)
println!("{:?}", nn.layer_sizes()); // [2, 4, 3, 1]

// Parameter count (accessor name is a placeholder)
println!("{}", nn.parameter_count()); // 25

// Save/load weights
let weights = nn.get_weights();
nn.set_weights(weights);
```
## API Reference
### Core Types
- **`Hextral`** - Main neural network struct
- **`ActivationFunction`** - Enum for activation functions
- **`Optimizer`** - Enum for optimization algorithms
- **`Regularization`** - Enum for regularization techniques
### Key Methods
- **`new()`** - Create a new neural network
- **`train()`** - Train the network for multiple epochs
- **`predict()`** - Make a single prediction
- **`evaluate()`** - Compute loss on a dataset
- **`set_regularization()`** - Configure regularization
## Performance Tips
1. **Use ReLU activation** for hidden layers in most cases
2. **Start with Adam optimizer** - it adapts learning rates automatically
3. **Apply L2 regularization** if you see overfitting (test loss > train loss)
4. **Use dropout for large networks** to prevent co-adaptation
5. **Normalize your input data** to [0,1] or [-1,1] range for better training stability
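As a companion to tip 5, here is a standalone sketch of min-max scaling a feature vector into [0, 1] with `nalgebra` (not part of Hextral's API):

```rust
use nalgebra::DVector;

/// Min-max scale a feature vector into [0, 1]; a constant feature is mapped to all zeros.
fn min_max_scale(v: &DVector<f64>) -> DVector<f64> {
    let min = v.min();
    let max = v.max();
    let range = max - min;
    if range == 0.0 {
        DVector::zeros(v.len())
    } else {
        v.map(|x| (x - min) / range)
    }
}

fn main() {
    let raw = DVector::from_vec(vec![10.0, 250.0, 42.0, 7.0]);
    println!("{}", min_max_scale(&raw));
}
```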
## Architecture Decisions
- **Built on nalgebra** for efficient linear algebra operations
- **Xavier initialization** for stable gradient flow from the start (see the sketch after this list)
- **Proper error handling** throughout the API
- **Modular design** allowing easy extension of activation functions and optimizers
- **Zero-copy predictions** where possible for performance
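For reference, Xavier (Glorot) uniform initialization can be sketched independently of the crate using `nalgebra` matrices and the `rand` crate (this is illustrative, not Hextral's internal code):

```rust
use nalgebra::DMatrix;
use rand::Rng;

/// Xavier (Glorot) uniform initialization for a `fan_out x fan_in` weight matrix:
/// weights are drawn from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)),
/// which keeps activation variance roughly constant across layers.
fn xavier_init(fan_in: usize, fan_out: usize) -> DMatrix<f64> {
    let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
    let mut rng = rand::thread_rng();
    DMatrix::from_fn(fan_out, fan_in, |_, _| rng.gen_range(-limit..limit))
}

fn main() {
    let w = xavier_init(4, 3); // weights for a layer mapping 4 inputs to 3 outputs
    println!("{}", w);
}
```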
## Contributing
We welcome contributions! Please feel free to:
- Report bugs by opening an issue
- Suggest new features or improvements
- Submit pull requests with enhancements
- Improve documentation
- Add more test cases
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Changelog
### v0.5.1 (Latest)
- **Improved Documentation**: Enhanced README with comprehensive examples of all new features
- **Better Crates.io Presentation**: Updated documentation to properly showcase library capabilities
### v0.5.0
- **Major Feature Expansion**: Added comprehensive loss functions, batch normalization, and modern activation functions
- **5 Loss Functions**: MSE, MAE, Binary Cross-Entropy, Categorical Cross-Entropy, Huber Loss
- **Batch Normalization**: Full implementation with training/inference modes
- **3 New Activation Functions**: Swish, GELU, Mish (total of 9 activation functions)
- **Code Organization**: Separated tests into dedicated files for cleaner library structure
- **Enhanced API**: Flexible loss function configuration and batch normalization controls
### v0.4.0 (Previous)
- **Complete rewrite** with proper error handling and fixed implementations
- **Implemented all documented features** - train(), predict(), evaluate() methods
- **Fixed critical bugs** in batch normalization and backward pass
- **Added regularization support** - L1, L2, and Dropout
- **Improved documentation** with usage examples and API reference