# rust-imbalanced-learn
**High-performance resampling techniques for imbalanced datasets in Rust**
[CI](https://github.com/timarocks/rust-imbalanced-learn/actions) | [crates.io](https://crates.io/crates/imbalanced-learn) | [docs.rs](https://docs.rs/imbalanced-learn)
A Rust implementation of resampling techniques for handling imbalanced datasets in machine learning. This library provides a high-performance alternative to Python's imbalanced-learn, designed for the Rust ML community with performance, safety, and modern Rust idioms in mind.
## Features
- Zero-cost abstractions with compile-time state checking
- SIMD-accelerated algorithms for maximum performance
- Parallel processing using Rayon for multi-core utilization
- Memory safety with Rust's ownership system
- GPU acceleration on Apple Silicon via Metal (optional)
- Comprehensive metrics for model evaluation
- Seamless integration with the Rust ML ecosystem
## Supported Algorithms
### Resampling Techniques
- **SMOTE** (Synthetic Minority Over-sampling Technique)
- **ADASYN** (Adaptive Synthetic Sampling) - Coming Soon
- **Random Under/Over Sampling** - Coming Soon
- **BorderlineSMOTE** - Coming Soon
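
As a rough illustration of the idea behind SMOTE (a dependency-free sketch, not this library's implementation): for a minority sample, pick one of its k nearest minority neighbors and create a synthetic point on the line segment between the two. The function names below are hypothetical, and the interpolation factor is fixed instead of being drawn at random so the result is reproducible.

```rust
/// Squared Euclidean distance between two feature vectors.
fn squared_distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Indices of the `k` minority samples closest to `minority[idx]` (brute force).
fn nearest_neighbors(minority: &[Vec<f64>], idx: usize, k: usize) -> Vec<usize> {
    let mut others: Vec<usize> = (0..minority.len()).filter(|&j| j != idx).collect();
    others.sort_by(|&a, &b| {
        squared_distance(&minority[idx], &minority[a])
            .total_cmp(&squared_distance(&minority[idx], &minority[b]))
    });
    others.truncate(k);
    others
}

/// SMOTE interpolation step: sample + lambda * (neighbor - sample).
/// In real use `lambda` is drawn uniformly from [0, 1].
fn interpolate(sample: &[f64], neighbor: &[f64], lambda: f64) -> Vec<f64> {
    sample
        .iter()
        .zip(neighbor)
        .map(|(s, n)| s + lambda * (n - s))
        .collect()
}

fn main() {
    // Four minority-class samples with two features each (made-up values).
    let minority = vec![
        vec![1.0, 1.0],
        vec![2.0, 1.0],
        vec![1.0, 3.0],
        vec![8.0, 8.0],
    ];

    let neighbors = nearest_neighbors(&minority, 0, 2);
    // Fixed lambda = 0.5 places the synthetic point halfway to the nearest neighbor.
    let synthetic = interpolate(&minority[0], &minority[neighbors[0]], 0.5);

    println!("neighbors of sample 0: {:?}", neighbors);
    println!("synthetic sample: {:?}", synthetic);
}
```

Because every synthetic point lies between two existing minority samples, the new data stays inside the region the minority class already occupies.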
### Ensemble Methods
- **Balanced Random Forest** - In Development
- **EasyEnsemble** - Coming Soon
- **RUSBoost** - Coming Soon
### Evaluation Metrics
- **Classification Report** with per-class metrics
- **Confusion Matrix** with parallel computation
- **F1 Score** (macro, micro, weighted)
- **Balanced Accuracy**
- **Precision, Recall, Support**
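
Plain accuracy can look good on imbalanced data even when the minority class is mostly missed, which is why the metrics above matter. The following dependency-free sketch computes per-class precision, recall, F1, and balanced accuracy (the mean of per-class recall). The function names are hypothetical and are not the `imbalanced-metrics` API.

```rust
/// Safe division that returns 0.0 when the denominator is zero.
fn ratio(num: usize, den: usize) -> f64 {
    if den == 0 { 0.0 } else { num as f64 / den as f64 }
}

/// Count samples where both the true and the predicted label equal `class`.
fn true_positives(y_true: &[u32], y_pred: &[u32], class: u32) -> usize {
    y_true
        .iter()
        .zip(y_pred)
        .filter(|&(&t, &p)| t == class && p == class)
        .count()
}

/// Recall: true positives divided by the number of true members of `class`.
fn recall(y_true: &[u32], y_pred: &[u32], class: u32) -> f64 {
    let support = y_true.iter().filter(|&&t| t == class).count();
    ratio(true_positives(y_true, y_pred, class), support)
}

/// Precision: true positives divided by the number of predictions of `class`.
fn precision(y_true: &[u32], y_pred: &[u32], class: u32) -> f64 {
    let predicted = y_pred.iter().filter(|&&p| p == class).count();
    ratio(true_positives(y_true, y_pred, class), predicted)
}

/// F1: harmonic mean of precision and recall for one class.
fn f1_score(y_true: &[u32], y_pred: &[u32], class: u32) -> f64 {
    let (p, r) = (precision(y_true, y_pred, class), recall(y_true, y_pred, class));
    if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
}

/// Balanced accuracy: unweighted mean of recall over the classes in `y_true`.
fn balanced_accuracy(y_true: &[u32], y_pred: &[u32]) -> f64 {
    let mut classes = y_true.to_vec();
    classes.sort_unstable();
    classes.dedup();
    if classes.is_empty() {
        return 0.0;
    }
    let total: f64 = classes.iter().map(|&c| recall(y_true, y_pred, c)).sum();
    total / classes.len() as f64
}

fn main() {
    // 8 majority samples (class 0) and 2 minority samples (class 1).
    let y_true: [u32; 10] = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1];
    // The classifier misses one of the two minority samples.
    let y_pred: [u32; 10] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1];

    let correct = y_true.iter().zip(&y_pred).filter(|&(t, p)| t == p).count();
    println!("accuracy:          {:.2}", ratio(correct, y_true.len())); // 0.90
    println!("balanced accuracy: {:.2}", balanced_accuracy(&y_true, &y_pred)); // 0.75
    println!("minority F1:       {:.2}", f1_score(&y_true, &y_pred, 1));
}
```

In this example accuracy is 0.90 while balanced accuracy is 0.75, because half of the minority class is misclassified.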
## Quick Start
Add to your `Cargo.toml`:
```toml
[dependencies]
imbalanced-core = "0.1"
imbalanced-sampling = "0.1"
imbalanced-metrics = "0.1"
ndarray = "0.15"
```
### Basic Usage
```rust
use imbalanced_core::prelude::*;
use imbalanced_sampling::prelude::*;
use imbalanced_metrics::prelude::*;
use ndarray::{Array1, Array2};
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load your imbalanced dataset (placeholder: supply your own loader)
    let (x, y) = load_imbalanced_data();

    // Apply SMOTE resampling with k = 5 nearest neighbors
    let smote = SmoteStrategy::new(5);
    let config = SmoteConfig::default();
    let (x_balanced, y_balanced) = smote.resample(x.view(), y.view(), &config)?;

    // Train your model on the balanced data
    // ... (using linfa, smartcore, or candle)

    // Evaluate on a held-out test set.
    // `model`, `x_test`, and `y_test` are placeholders for your own model and test split.
    let y_pred = model.predict(&x_test);
    let report = classification_report(y_test.view(), y_pred.view());
    println!("{}", report);

    Ok(())
}
```
## Architecture
The library is organized as a workspace with focused crates:
```
rust-imbalanced-learn/
├── imbalanced-core/ # Core traits and abstractions
├── imbalanced-sampling/ # Resampling algorithms (SMOTE, etc.)
├── imbalanced-ensemble/ # Ensemble methods
└── imbalanced-metrics/ # Evaluation metrics
```
### Type-Safe State Management
```rust
use imbalanced_core::prelude::*;
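// `config` is a configuration value such as `SmoteConfig::default()` from Basic Usage.
// Compile-time state checking prevents misuse
let resampler = Resampler::new(SmoteStrategy::new(5)) // state: Uninitialized
    .configure(config);                               // state: Configured
// Type-safe transitions ensure correct usage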
```
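
The general technique behind this kind of compile-time state checking is often called the typestate pattern. The sketch below shows the idea in isolation, assuming nothing about the real `Resampler` type: `SketchResampler`, its methods, and the single `k_neighbors` setting are hypothetical names used only for illustration.

```rust
/// State markers. `Configured` carries the data that only exists after configuration.
struct Uninitialized;
struct Configured {
    k_neighbors: usize,
}

/// A resampler whose type parameter records which state it is in.
struct SketchResampler<State> {
    state: State,
}

impl SketchResampler<Uninitialized> {
    fn new() -> Self {
        SketchResampler { state: Uninitialized }
    }

    /// Consumes the unconfigured value, so the old state cannot be reused.
    fn configure(self, k_neighbors: usize) -> SketchResampler<Configured> {
        SketchResampler {
            state: Configured { k_neighbors },
        }
    }
}

impl SketchResampler<Configured> {
    fn k_neighbors(&self) -> usize {
        self.state.k_neighbors
    }

    /// Number of synthetic minority samples needed to match the majority count.
    fn samples_needed(&self, minority: usize, majority: usize) -> usize {
        majority.saturating_sub(minority)
    }
}

fn main() {
    let resampler = SketchResampler::new().configure(5);
    println!(
        "k = {}, samples needed = {}",
        resampler.k_neighbors(),
        resampler.samples_needed(10, 90)
    );

    // This would not compile, because `samples_needed` is only defined
    // for `SketchResampler<Configured>`:
    // SketchResampler::new().samples_needed(10, 90);
}
```

The state lives only in the type, so the check costs nothing at runtime: an unconfigured resampler simply has no resampling methods to call.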
## Integration with Rust ML Ecosystem
### Linfa Integration
```rust
use linfa::prelude::*;
use linfa_trees::DecisionTree;
let dataset = Dataset::new(x_balanced, y_balanced);
let model = DecisionTree::params()
    .max_depth(Some(10))
    .fit(&dataset)?;
```
### SmartCore Integration
```rust
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;
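// Assumes smartcore 0.3, `x_balanced: Array2<f64>`, and `i32` class labels in `y_balanced`.
// `from_2d_array` expects `&[&[T]]`, so copy the rows into a Vec<Vec<f64>> instead.
let rows: Vec<Vec<f64>> = x_balanced.outer_iter().map(|row| row.to_vec()).collect();
let x_matrix = DenseMatrix::from_2d_vec(&rows);
let labels: Vec<i32> = y_balanced.to_vec();
let model = DecisionTreeClassifier::fit(&x_matrix, &labels, Default::default())?;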
```
### Polars Integration
```rust
use polars::prelude::*;
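// `CsvReader::from_path` is available in older Polars releases; newer ones use `CsvReadOptions`.
let df = CsvReader::from_path("imbalanced_data.csv")?.finish()?;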
let balanced_df = smote.resample_dataframe(&df, "target_column")?;
```
## Performance
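In the benchmarks below, the Rust implementation of SMOTE runs roughly 12x to 17.5x faster than the Python implementation:

| Algorithm | Dataset size | Rust | Python | Speedup |
|-----------|--------------|------|--------|---------|
| SMOTE | 10K samples | 15ms | 180ms | 12x |
| SMOTE | 100K samples | 120ms | 2.1s | 17.5x |
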
*Benchmarks run on M1 MacBook Pro with optimized release builds*
## Advanced Features
### SIMD Acceleration
```rust
let smote = SmoteStrategy::new(5)
    .with_performance_hints(
        PerformanceHints::new()
            .with_hint(PerformanceHint::Vectorize)
            .with_hint(PerformanceHint::Parallel),
    );
```
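
How the library acts on these hints is internal to it. As a rough illustration of the kind of work that benefits from parallelism, the sketch below splits a brute-force nearest-neighbor search across threads. It uses only the standard library (`std::thread::scope`) to stay dependency-free, whereas the feature list above mentions Rayon; all function names are hypothetical.

```rust
use std::thread;

/// Squared Euclidean distance, written as a simple iterator chain
/// that the compiler may be able to auto-vectorize.
fn squared_distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Index of the row in `data` closest to `query` (brute force).
fn nearest(data: &[Vec<f64>], query: &[f64]) -> usize {
    let mut best = 0;
    let mut best_dist = f64::INFINITY;
    for (i, row) in data.iter().enumerate() {
        let d = squared_distance(row, query);
        if d < best_dist {
            best = i;
            best_dist = d;
        }
    }
    best
}

/// Nearest row index for every query, with the queries split into
/// contiguous chunks that are processed on separate threads.
fn nearest_parallel(data: &[Vec<f64>], queries: &[Vec<f64>], n_threads: usize) -> Vec<usize> {
    let n_threads = n_threads.max(1);
    let chunk_size = ((queries.len() + n_threads - 1) / n_threads).max(1);

    thread::scope(|scope| {
        // Spawn every worker first, then join them in order so the
        // results line up with the original query order.
        let handles: Vec<_> = queries
            .chunks(chunk_size)
            .map(|chunk| {
                scope.spawn(move || {
                    chunk
                        .iter()
                        .map(|q| nearest(data, q))
                        .collect::<Vec<usize>>()
                })
            })
            .collect();

        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect::<Vec<usize>>()
    })
}

fn main() {
    let data = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
    let queries = vec![
        vec![1.0, 1.0],
        vec![9.0, 9.0],
        vec![0.0, 1.0],
        vec![11.0, 10.0],
    ];
    println!("{:?}", nearest_parallel(&data, &queries, 2));
}
```

Each query is independent of the others, so the search needs no locking; the threads only share read-only references to the data.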
### GPU Acceleration (Apple Silicon)
```toml
[dependencies]
imbalanced-core = { version = "0.1", features = ["metal-acceleration"] }
```
```rust
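// Requires macOS and the `metal-acceleration` feature; call from an async context.
#[cfg(target_os = "macos")]
use imbalanced_core::platform::MetalKNN;
// `device`, `x`, `query`, and `k` are placeholders for your Metal device, data, and neighbor count.
let knn = MetalKNN::new(device)?;
let neighbors = knn.find_neighbors(&x, &query, k).await?;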
```
## Examples
Run the included examples:
```bash
# Basic SMOTE usage
cargo run --example basic_usage -p imbalanced-sampling
# Comprehensive pipeline
cargo run --example pipeline -p imbalanced-sampling
# Performance benchmarks
cargo run --example benchmarks -p imbalanced-sampling --release
```
## Development
### Building
```bash
git clone https://github.com/timarocks/rust-imbalanced-learn
cd rust-imbalanced-learn
cargo build --release
```
### Testing
```bash
cargo test --all
```
### Benchmarking
```bash
cargo bench --all
```
## Contributing
Contributions welcome! Please read our Contributing Guide and Code of Conduct.
### Priority Areas
- Additional resampling algorithms (ADASYN, BorderlineSMOTE)
- More ensemble methods (EasyEnsemble, RUSBoost)
- Advanced GPU acceleration
- Integration with more ML frameworks
- Performance optimizations
## Citation
If you use rust-imbalanced-learn in your research, please cite:
```bibtex
@software{rust_imbalanced_learn,
title={rust-imbalanced-learn: High-performance resampling for imbalanced datasets},
author={Tim},
year={2024},
url={https://github.com/timarocks/rust-imbalanced-learn}
}
```
## License
Licensed under either of
- Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE))
- MIT License ([LICENSE-MIT](LICENSE-MIT))
at your option.
## Acknowledgements
- [imbalanced-learn](https://imbalanced-learn.org/) - Original Python implementation
- [scikit-learn](https://scikit-learn.org/) - Machine learning fundamentals
- [Rust ML community](https://github.com/rust-ml) - Ecosystem foundation
---
**Built for the Rust ML community**