# Hextral
A high-performance neural network library for Rust with a clean async-first API, advanced activation functions, multiple optimizers, early stopping, and checkpointing.
[crates.io](https://crates.io/crates/hextral) · [docs.rs](https://docs.rs/hextral)
## Features
### **Core Architecture**
- **Multi-layer perceptrons** with configurable hidden layers
- **Batch normalization** for improved training stability and convergence
- **Xavier weight initialization** for stable gradient flow
- **Flexible network topology** - specify any number of hidden layers and neurons
- **Clean async-first API** with intelligent yielding for non-blocking operations
### **Activation Functions (10 Available)**
- **ReLU** - Rectified Linear Unit (good for most cases)
- **Sigmoid** - Smooth activation for binary classification
- **Tanh** - Hyperbolic tangent for centered outputs
- **Leaky ReLU** - Prevents dying ReLU problem
- **ELU** - Exponential Linear Unit for smoother gradients
- **Linear** - For regression output layers
- **Swish** - Modern activation with smooth derivatives
- **GELU** - Gaussian Error Linear Unit used in transformers
- **Mish** - Self-regularizing activation function
- **Quaternion** - Quaternion-based normalization for 4D data
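
For reference, the standard mathematical definitions of several of these activations (textbook forms, shown for orientation; not code from this crate):

```math
\begin{aligned}
\mathrm{ReLU}(x) &= \max(0, x), &\qquad \sigma(x) &= \frac{1}{1 + e^{-x}}, \\
\mathrm{Swish}_{\beta}(x) &= x \cdot \sigma(\beta x), &\qquad \mathrm{GELU}(x) &= x \cdot \Phi(x), \\
\mathrm{Mish}(x) &= x \cdot \tanh\!\big(\ln(1 + e^{x})\big), &\qquad \mathrm{ELU}_{\alpha}(x) &= \begin{cases} x & x > 0 \\ \alpha\,(e^{x} - 1) & x \le 0 \end{cases}
\end{aligned}
```

Here $\sigma$ is the logistic sigmoid and $\Phi$ is the standard normal CDF.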
### **Loss Functions (5 Available)**
- **Mean Squared Error (MSE)** - Standard regression loss
- **Mean Absolute Error (MAE)** - Robust to outliers
- **Binary Cross-Entropy** - Binary classification
- **Categorical Cross-Entropy** - Multi-class classification
- **Huber Loss** - Robust hybrid of MSE and MAE
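
The standard forms of these losses, for targets $y$ and predictions $\hat{y}$ over $n$ samples (shown for reference):

```math
\begin{aligned}
\mathrm{MSE} &= \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2, \qquad
\mathrm{MAE} = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|, \\
\mathrm{BCE} &= -\frac{1}{n} \sum_{i=1}^{n} \big[\, y_i \ln \hat{y}_i + (1 - y_i) \ln (1 - \hat{y}_i) \,\big], \qquad
\mathrm{CCE} = -\sum_{c} y_c \ln \hat{y}_c, \\
\mathrm{Huber}_{\delta}(r) &= \begin{cases} \tfrac{1}{2} r^2 & |r| \le \delta \\ \delta\,\big(|r| - \tfrac{1}{2}\delta\big) & |r| > \delta \end{cases}
\quad \text{with } r = y - \hat{y}
\end{aligned}
```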
### **Optimization Algorithms (12 Available)**
- **Adam** - Adaptive moment estimation (recommended for most cases; its standard update rule is sketched after this list)
- **AdamW** - Adam with decoupled weight decay
- **NAdam** - Nesterov-accelerated Adam
- **AdaBelief** - Adapting stepsizes by belief in observed gradients
- **Lion** - Evolved sign momentum optimizer
- **SGD** - Stochastic Gradient Descent (simple and reliable)
- **SGD with Momentum** - Accelerated gradient descent
- **RMSprop** - Root mean square propagation
- **AdaGrad** - Adaptive gradient algorithm
- **AdaDelta** - Extension of AdaGrad
- **LBFGS** - Limited-memory BFGS (quasi-Newton method)
- **Ranger** - Combination of RAdam and LookAhead
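
As background on why Adam is the recommended default: it keeps exponential moving averages of the gradient and of its square, and uses them to scale each parameter's step. The textbook update rule, for gradient $g_t$ (symbols $\beta_1$, $\beta_2$, $\eta$, $\epsilon$ are the usual Adam hyperparameters, not necessarily the crate's field names), is:

```math
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2, \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad
\theta_t = \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
```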
### **Advanced Training Features**
- **Early Stopping** - Automatic training termination based on validation loss
- **Checkpointing** - Save and restore model weights with bincode serialization
- **Regularization** - L1/L2 regularization and dropout support
- **Batch Training** - Configurable batch sizes for memory efficiency
- **Training Progress Tracking** - Loss history and validation monitoring
- **Dual sync/async API** for both blocking and non-blocking operations
### **Async/Concurrent Processing**
- **Async training methods** with cooperative multitasking
- **Parallel batch prediction** using futures
- **Intelligent yielding** - only yields for large workloads (>1000 elements)
- **Concurrent activation function processing**
- **Performance-optimized** async implementation alongside synchronous methods
## Quick Start
Add this to your `Cargo.toml`:
```toml
[dependencies]
hextral = "0.7.0"
nalgebra = "0.33"
tokio = { version = "1.0", features = ["full"] } # For async features
```
### Basic Async Usage (Recommended)
```rust
use hextral::{Hextral, ActivationFunction, Optimizer, EarlyStopping, CheckpointConfig};
use nalgebra::DVector;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Create a neural network: 2 inputs -> [4, 3] hidden -> 1 output
    let mut nn = Hextral::new(
        2,                        // Input features
        &[4, 3],                  // Hidden layer sizes
        1,                        // Output size
        ActivationFunction::ReLU, // Activation function
        Optimizer::adam(0.01),    // Modern Adam optimizer
    );

    // Training data for the XOR problem
    let train_inputs = vec![
        DVector::from_vec(vec![0.0, 0.0]),
        DVector::from_vec(vec![0.0, 1.0]),
        DVector::from_vec(vec![1.0, 0.0]),
        DVector::from_vec(vec![1.0, 1.0]),
    ];
    let train_targets = vec![
        DVector::from_vec(vec![0.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![0.0]),
    ];

    // Validation data (can be the same as the training data for this demo)
    let val_inputs = train_inputs.clone();
    let val_targets = train_targets.clone();

    // Configure early stopping and checkpointing
    let early_stopping = EarlyStopping::new(10, 0.001, true);
    let checkpoint_config = CheckpointConfig::new("best_model".to_string());

    // Train the network with advanced features
    println!("Training network with early stopping...");
    let (train_history, val_history) = nn.train(
        &train_inputs,
        &train_targets,
        0.1,                     // Learning rate
        1000,                    // Max epochs
        Some(2),                 // Batch size
        Some(&val_inputs),       // Validation inputs
        Some(&val_targets),      // Validation targets
        Some(early_stopping),    // Early stopping
        Some(checkpoint_config), // Checkpointing
    ).await?;

    println!("Training completed after {} epochs", train_history.len());
    println!("Final validation loss: {:.6}", val_history.last().unwrap_or(&0.0));

    // Make predictions
    println!("\nPredictions:");
    for (input, expected) in train_inputs.iter().zip(train_targets.iter()) {
        let prediction = nn.predict(input).await;
        println!(
            "Input: {:?} | Expected: {:.1} | Predicted: {:.3}",
            input.data.as_vec(), expected[0], prediction[0]
        );
    }

    // Batch prediction (efficient for multiple inputs)
    let batch_predictions = nn.predict_batch(&train_inputs).await;

    // Evaluate performance
    let final_loss = nn.evaluate(&train_inputs, &train_targets).await;
    println!("Final loss: {:.6}", final_loss);

    Ok(())
}
```
### Advanced Features
```rust
use hextral::*;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Create a network with an advanced activation function
    let mut nn = Hextral::new(
        4, &[8, 6], 2,
        ActivationFunction::Swish { beta: 1.0 }, // Modern Swish activation
        Optimizer::adamw(0.001, 0.01),           // AdamW with weight decay
    );

    // Enable batch normalization for better training stability
    nn.enable_batch_norm();
    nn.set_training_mode(true);

    // Configure regularization
    nn.set_regularization(Regularization::L2(0.001));

    let inputs = vec![/* your training data */];
    let targets = vec![/* your target data */];

    // Advanced training with all features
    let early_stop = EarlyStopping::new(
        15,     // Patience: stop if no improvement for 15 epochs
        0.0001, // Minimum improvement threshold
        true,   // Restore best weights when stopping
    );
    let checkpoint = CheckpointConfig::new("model_checkpoint".to_string())
        .save_every(10); // Save every 10 epochs

    let (train_losses, val_losses) = nn.train(
        &inputs, &targets,
        0.01,             // Learning rate
        500,              // Max epochs
        Some(32),         // Batch size
        Some(&inputs),    // Validation inputs
        Some(&targets),   // Validation targets
        Some(early_stop), // Early stopping
        Some(checkpoint), // Checkpointing
    ).await?;

    // Switch to inference mode
    nn.set_training_mode(false);

    Ok(())
}
```
The async-first design also makes Hextral a natural fit for concurrent workloads:
- **Scalable architecture** - ideal for web services and concurrent applications
- **Parallel batch processing** - multiple predictions processed concurrently using futures
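
As a small illustration of that second point (example code only; it uses nothing beyond the documented `predict_batch` method plus standard tokio), two independent batch predictions can be awaited concurrently on the tokio runtime:

```rust
use hextral::{ActivationFunction, Hextral, Optimizer};
use nalgebra::DVector;

#[tokio::main]
async fn main() {
    // Small throwaway network purely for demonstration.
    let nn = Hextral::new(3, &[8], 1, ActivationFunction::ReLU, Optimizer::default());

    // Two independent input batches (placeholder data).
    let batch_a = vec![DVector::from_vec(vec![0.1, 0.2, 0.3])];
    let batch_b = vec![DVector::from_vec(vec![0.4, 0.5, 0.6])];

    // Await both batch predictions concurrently; the cooperative yielding
    // described above lets large workloads interleave on the runtime.
    let (preds_a, preds_b) = tokio::join!(
        nn.predict_batch(&batch_a),
        nn.predict_batch(&batch_b)
    );

    println!("Got {} + {} predictions", preds_a.len(), preds_b.len());
}
```

Because the library only yields for large workloads, the overhead of this pattern stays low for small batches.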
### Loss Functions
Configure different loss functions for your specific task:
```rust
use hextral::{Hextral, LossFunction, ActivationFunction, Optimizer};
let mut nn = Hextral::new(2, &[4], 1, ActivationFunction::ReLU, Optimizer::default());
// For regression tasks
nn.set_loss_function(LossFunction::MeanSquaredError);
nn.set_loss_function(LossFunction::MeanAbsoluteError);
nn.set_loss_function(LossFunction::Huber { delta: 1.0 });
// For classification tasks
nn.set_loss_function(LossFunction::BinaryCrossEntropy);
nn.set_loss_function(LossFunction::CategoricalCrossEntropy);
```
### Batch Normalization
Enable batch normalization for improved training stability:
```rust
use hextral::{Hextral, ActivationFunction, Optimizer};
let mut nn = Hextral::new(10, &[64, 32], 1, ActivationFunction::ReLU, Optimizer::default());
// Enable batch normalization
nn.enable_batch_norm();
// Set training mode
nn.set_training_mode(true);
// Train your network (inside an async context; see the full `train` signature in the API reference)...
let (train_losses, _val_losses) = nn.train(
    &inputs, &targets,
    0.01,       // Learning rate
    100,        // Epochs
    None,       // Batch size
    None, None, // Validation inputs / targets
    None,       // Early stopping
    None,       // Checkpointing
).await?;
// Switch to inference mode
nn.set_training_mode(false);
// Make predictions...
let prediction = nn.predict(&input).await;
```
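
For context, batch normalization uses the standard formulation below (shown for reference): each mini-batch of layer inputs is standardized with the batch mean $\mu_B$ and variance $\sigma_B^2$, then rescaled with learned parameters $\gamma$ and $\beta$.

```math
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma\, \hat{x}_i + \beta
```

At inference time, running estimates of the mean and variance replace the batch statistics, which is why switching between training and inference mode matters.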
### Modern Activation Functions
Use state-of-the-art activation functions:
```rust
use hextral::{Hextral, ActivationFunction, Optimizer};
// Swish activation (used in EfficientNet)
let mut nn = Hextral::new(2, &[4], 1,
    ActivationFunction::Swish { beta: 1.0 }, Optimizer::default());

// GELU activation (used in BERT, GPT)
let mut nn = Hextral::new(2, &[4], 1,
    ActivationFunction::GELU, Optimizer::default());

// Mish activation (self-regularizing)
let mut nn = Hextral::new(2, &[4], 1,
    ActivationFunction::Mish, Optimizer::default());
```
### Regularization
Prevent overfitting with built-in regularization techniques:
```rust
use hextral::{Hextral, Regularization, ActivationFunction, Optimizer};
let mut nn = Hextral::new(3, &[16, 8], 1, ActivationFunction::ReLU,
    Optimizer::Adam { learning_rate: 0.01 });
// L2 regularization (Ridge)
nn.set_regularization(Regularization::L2(0.01));
// L1 regularization (Lasso)
nn.set_regularization(Regularization::L1(0.005));
// Dropout regularization
nn.set_regularization(Regularization::Dropout(0.3));
```
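
These follow the standard penalty definitions: with data loss $L$, weights $w$, and the coefficient $\lambda$ passed above, the regularized objectives are

```math
L_{\mathrm{L2}} = L + \lambda \sum_i w_i^2, \qquad
L_{\mathrm{L1}} = L + \lambda \sum_i |w_i|
```

while dropout randomly zeroes each hidden activation with the given probability (here 0.3) during training.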
### Different Optimizers
Choose the optimizer that works best for your problem:
```rust
use hextral::Optimizer;

// Adam: Good default choice, adaptive learning rates
let optimizer = Optimizer::Adam { learning_rate: 0.001 };
// SGD: Simple and interpretable
let optimizer = Optimizer::SGD { learning_rate: 0.1 };
// SGD with Momentum: Accelerated convergence
let optimizer = Optimizer::SGDMomentum {
    learning_rate: 0.1,
    momentum: 0.9,
};
```
### Network Introspection
Get insights into your network:
```rust
// Network architecture
println!("Architecture: {:?}", nn.architecture()); // [2, 4, 3, 1]
// Parameter count (weights + biases; 31 for the [2, 4, 3, 1] network above)
println!("Total parameters: {}", nn.parameter_count());
// Save/load weights
let weights = nn.get_weights();
nn.set_weights(weights);
```
## API Reference
### Core Types
- **`Hextral`** - Main neural network struct with async-first API
- **`ActivationFunction`** - Enum for activation functions (10 available)
- **`Optimizer`** - Enum for optimization algorithms (12 available)
- **`Regularization`** - Enum for regularization techniques
- **`EarlyStopping`** - Configuration for automatic training termination
- **`CheckpointConfig`** - Configuration for model checkpointing
- **`LossFunction`** - Enum for loss functions (5 available)
### Primary Methods (All Async)
```rust
// Network creation
Hextral::new(inputs, hidden_layers, outputs, activation, optimizer) -> Hextral
// Training with full feature set
async fn train(
&mut self,
train_inputs: &[DVector<f64>],
train_targets: &[DVector<f64>],
learning_rate: f64,
epochs: usize,
batch_size: Option<usize>,
val_inputs: Option<&[DVector<f64>]>,
val_targets: Option<&[DVector<f64>]>,
early_stopping: Option<EarlyStopping>,
checkpoint_config: Option<CheckpointConfig>,
) -> Result<(Vec<f64>, Vec<f64>), Box<dyn std::error::Error>>
// Predictions
async fn predict(&self, input: &DVector<f64>) -> DVector<f64>
async fn predict_batch(&self, inputs: &[DVector<f64>]) -> Vec<DVector<f64>>
// Evaluation
async fn evaluate(&self, inputs: &[DVector<f64>], targets: &[DVector<f64>]) -> f64
// Forward pass
async fn forward(&self, input: &DVector<f64>) -> DVector<f64>
```
### Configuration Methods
```rust
// Batch normalization
fn enable_batch_norm(&mut self)
fn disable_batch_norm(&mut self)
fn set_training_mode(&mut self, training: bool)
// Regularization
fn set_regularization(&mut self, reg: Regularization)
// Loss function
fn set_loss_function(&mut self, loss: LossFunction)
// Weight management
fn get_weights(&self) -> Vec<(DMatrix<f64>, DVector<f64>)>
fn set_weights(&mut self, weights: Vec<(DMatrix<f64>, DVector<f64>)>)
fn parameter_count(&self) -> usize
```
### Early Stopping & Checkpointing
```rust
// Early stopping configuration:
// EarlyStopping::new(patience, min_delta, restore_best_weights)
let early_stop = EarlyStopping::new(
    10,    // Patience: epochs to wait for improvement
    0.001, // Minimum improvement threshold
    true,  // Whether to restore the best weights when stopping
);
// Checkpoint configuration
let checkpoint = CheckpointConfig::new("model_path".to_string())
.save_every(10) // Save every N epochs
.save_best(true); // Save best model based on validation loss
```
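
Conceptually, patience-based early stopping works like the sketch below: track the best validation loss seen so far, count epochs whose improvement falls below `min_delta`, and stop once that count reaches `patience`. This is an illustrative stand-in with made-up numbers and a dummy validation curve, not the crate's internal implementation:

```rust
fn main() {
    // Hypothetical settings matching the EarlyStopping parameters above.
    let (patience, min_delta, max_epochs) = (10usize, 0.001f64, 100usize);
    let mut best_val = f64::INFINITY;
    let mut epochs_without_improvement = 0usize;

    for epoch in 0..max_epochs {
        // Stand-in for a real validation pass: a decreasing-then-flat curve.
        let val_loss = 1.0 / (1.0 + epoch as f64).min(20.0);

        if best_val - val_loss > min_delta {
            best_val = val_loss;            // improvement: reset the counter
            epochs_without_improvement = 0; // (and optionally snapshot the weights here)
        } else {
            epochs_without_improvement += 1;
            if epochs_without_improvement >= patience {
                println!("early stopping at epoch {}", epoch);
                break;
            }
        }
    }
}
```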
## Performance Tips
1. **Use ReLU activation** for hidden layers in most cases
2. **Start with Adam optimizer** - it adapts learning rates automatically
3. **Apply L2 regularization** if you see overfitting (validation loss climbing while training loss keeps dropping)
4. **Use dropout for large networks** to prevent co-adaptation
5. **Normalize your input data** to the [0,1] or [-1,1] range for better training stability (see the scaling sketch below)
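
As an illustration of tip 5, here is a minimal min-max scaling helper built only on nalgebra; the function name and the choice of a single global range (rather than per-feature ranges) are this example's own assumptions, not part of Hextral:

```rust
use nalgebra::DVector;

/// Scale every element of every vector into [0, 1] using the global min/max.
/// Hypothetical helper shown only to illustrate input normalization.
fn min_max_scale(inputs: &[DVector<f64>]) -> Vec<DVector<f64>> {
    let min = inputs.iter().flat_map(|v| v.iter().copied()).fold(f64::INFINITY, f64::min);
    let max = inputs.iter().flat_map(|v| v.iter().copied()).fold(f64::NEG_INFINITY, f64::max);
    let range = (max - min).max(f64::EPSILON); // guard against a constant-valued dataset

    inputs.iter().map(|v| v.map(|x| (x - min) / range)).collect()
}

fn main() {
    let raw = vec![
        DVector::from_vec(vec![10.0, 20.0]),
        DVector::from_vec(vec![30.0, 40.0]),
    ];
    let scaled = min_max_scale(&raw);
    println!("{:?}", scaled[1].as_slice()); // [0.666..., 1.0]
}
```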
## Architecture Decisions
- **Built on nalgebra** for efficient linear algebra operations
- **Xavier initialization** for stable gradient flow from the start
- **Proper error handling** throughout the API
- **Modular design** allowing easy extension of activation functions and optimizers
- **Zero-copy predictions** where possible for performance
## Contributing
We welcome contributions! Please feel free to:
- Report bugs by opening an issue
- Suggest new features or improvements
- Submit pull requests with enhancements
- Improve documentation
- Add more test cases

For major changes, please open an issue first to discuss what you would like to change.
## Changelog
### v0.7.0 (Latest)
- **Removed Redundancy**: Eliminated confusing duplicate methods and verbose naming patterns
- **Better Performance**: Streamlined async implementation with intelligent yielding
- **Updated Documentation**: All examples now use clean, consistent API
- **All Tests Updated**: Comprehensive test suite updated for new API patterns
### v0.6.0
- **Full Async/Await Support**: Complete async API alongside synchronous methods
- **Intelligent Yielding**: Performance-optimized async with yielding only for large workloads (>1000 elements)
- **Concurrent Processing**: Parallel batch predictions using futures and join_all
- **Async Training**: Non-blocking training with cooperative multitasking
- **Code Optimization**: Removed verbose AI-generated patterns, cleaner professional code
- **Performance Improvements**: Smart async yielding prevents unnecessary overhead
- **Enhanced Documentation**: Updated examples and API documentation
### v0.5.1
- **Improved Documentation**: Enhanced README with comprehensive examples of all new features
- **Better Crates.io Presentation**: Updated documentation to properly showcase library capabilities
### v0.5.0
- **Major Feature Expansion**: Added comprehensive loss functions, batch normalization, and modern activation functions
- **5 Loss Functions**: MSE, MAE, Binary Cross-Entropy, Categorical Cross-Entropy, Huber Loss
- **Batch Normalization**: Full implementation with training/inference modes
- **3 New Activation Functions**: Swish, GELU, Mish (total of 9 activation functions)
- **Code Organization**: Separated tests into dedicated files for cleaner library structure
- **Enhanced API**: Flexible loss function configuration and batch normalization controls
### v0.4.0
- **Complete rewrite** with proper error handling and fixed implementations
- **Implemented all documented features** - train(), predict(), evaluate() methods
- **Fixed critical bugs** in batch normalization and backward pass
- **Added regularization support** - L1, L2, and Dropout
- **Improved documentation** with usage examples and API reference
## License
This project is licensed under the MIT OR Apache-2.0 license.