# scry-learn
A scikit-learn-style ML toolkit in pure Rust. No Python runtime, no BLAS, no LAPACK — `cargo add scry-learn` and build.
```rust
use scry_learn::prelude::*;

// Quickstart. The CSV path, split arguments, and hyperparameter
// values below are illustrative placeholders.
let data = Dataset::from_csv("data.csv")?;
let (train, test) = train_test_split(&data, 0.2, Some(42));

let mut rf = RandomForestClassifier::new()
    .n_estimators(100)
    .max_depth(10);
rf.fit(&train)?;

let preds = rf.predict(&test)?;
println!("{preds:?}");
```
Status: 0.x, pre-1.0 — breaking changes are possible between minor versions. This started as a learning project (implement each algorithm from scratch to understand it) and grew into something usable, but it isn't a drop-in replacement for an established stack. If you need a settled Rust ML library today, linfa is the safer pick. The benchmarks below let you decide whether scry-learn fits your use case.
## What's in the box
- Pure-Rust dependencies. No Python, no BLAS/LAPACK, no system libraries.
- TreeSHAP and permutation importance built in.
- Cross-library benchmarks vs linfa and smartcore, with a counting allocator, single-thread enforcement, and accuracy parity gates. Run them yourself.
- Column-major data layout — tree models scan features without a transpose (sketched after this list).
- `#![deny(unsafe_code)]` enforced crate-wide.
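To make the layout point concrete, here is a minimal sketch of why column-major storage suits tree training. The types and names are illustrative, not scry-learn's actual internals:

```rust
// Illustrative sketch, not scry-learn's internals: column-major
// storage keeps every value of one feature contiguous, so a tree's
// split search walks a single cache-friendly slice per feature
// instead of striding across rows or transposing first.
struct ColMajorMatrix {
    cols: Vec<Vec<f64>>, // cols[j][i] = feature j of sample i
}

impl ColMajorMatrix {
    fn feature(&self, j: usize) -> &[f64] {
        &self.cols[j]
    }
}

fn main() {
    let x = ColMajorMatrix {
        cols: vec![vec![1.0, 4.0, 2.0], vec![0.5, 0.1, 0.9]],
    };
    // Candidate-split scan over feature 0: one linear pass.
    let max = x.feature(0).iter().cloned().fold(f64::MIN, f64::max);
    println!("feature 0 max = {max}");
}
```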
## Algorithms

### Supervised
| Model | Classification | Regression |
|---|---|---|
| Decision Tree (CART) | ✓ | ✓ |
| Random Forest | ✓ | ✓ |
| Gradient Boosting | ✓ | ✓ |
| Histogram Gradient Boosting | ✓ | ✓ |
| Linear / Logistic Regression | ✓ | ✓ |
| Ridge | — | ✓ |
| Lasso | — | ✓ |
| ElasticNet | — | ✓ |
| Linear SVM | ✓ | ✓ |
| Kernel SVM | ✓* | ✓* |
| K-Nearest Neighbors | ✓ | ✓ |
| Gaussian Naive Bayes | ✓ | — |
| Multinomial Naive Bayes | ✓ | — |
| Bernoulli Naive Bayes | ✓ | — |
| MLP Neural Network | ✓ | ✓ |
\* Kernel SVM requires `features = ["experimental"]`
### Clustering

| Algorithm | Notes |
|---|---|
| K-Means | k-means++ init, configurable max_iter |
| Mini-Batch K-Means | Streaming-friendly variant |
| DBSCAN | Density-based, automatic cluster count |
| HDBSCAN | Hierarchical density-based |
| Agglomerative | Ward / complete / average / single linkage |
### Preprocessing

- Scaling: StandardScaler, MinMaxScaler, RobustScaler, Normalizer (L1/L2)
- Encoding: OneHotEncoder, LabelEncoder
- Imputation: SimpleImputer (mean, median, most-frequent, constant)
- Dimensionality: PCA, VarianceThreshold, SelectKBest (f_classif)
- Transforms: PolynomialFeatures, ColumnTransformer, Pipeline (see the sketch after this list)
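A sketch of the preprocessing flow. The `StandardScaler::new().fit_transform(&mut data)?` shape matches the sklearn mapping table later in this README; the `Pipeline` chaining shown in the comment is an assumption about the builder API, not confirmed usage:

```rust
use scry_learn::prelude::*;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut data = Dataset::from_csv("data.csv")?; // path is illustrative

    // In-place scaling, as in the sklearn mapping table.
    StandardScaler::new().fit_transform(&mut data)?;

    // Chaining steps through Pipeline is sketched here; the exact
    // builder methods are assumptions:
    // let pipe = Pipeline::new().add(StandardScaler::new()).add(PCA::new(2));
    Ok(())
}
```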
### Model selection

- Search: GridSearchCV, RandomizedSearchCV, BayesSearchCV
- Validation: cross_val_score, stratified k-fold, group k-fold, time series split, repeated CV (see the sketch below)
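The `cross_val_score` call shape below is taken from the sklearn mapping table later in this README; the dataset loading and printing around it are illustrative assumptions:

```rust
use scry_learn::prelude::*;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let data = Dataset::from_csv("data.csv")?; // path is illustrative
    let rf = RandomForestClassifier::new().n_estimators(100);

    // 5-fold cross-validation scored with the accuracy metric.
    let scores = cross_val_score(&rf, &data, 5, accuracy);
    println!("{scores:?}");
    Ok(())
}
```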
### Metrics

- Classification metrics: accuracy, precision, recall, F1, balanced accuracy, Cohen's kappa, confusion matrix, ROC AUC, PR curve, log loss
- Regression metrics: MSE, MAPE, R², explained variance
- Clustering metrics: silhouette score, Calinski-Harabasz, Davies-Bouldin, adjusted Rand index
- Calibration: Platt scaling, isotonic regression (Platt's sigmoid sketched below)
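Platt scaling fits a sigmoid that maps raw classifier scores to probabilities. The sketch below shows only the forward mapping with toy parameters; how scry-learn fits `a` and `b` is not shown here:

```rust
// Platt scaling: p(y = 1 | s) = 1 / (1 + exp(a*s + b)), where a and b
// are fitted on held-out scores. Parameter values here are toy numbers.
fn platt(s: f64, a: f64, b: f64) -> f64 {
    1.0 / (1.0 + (a * s + b).exp())
}

fn main() {
    let (a, b) = (-1.7, 0.0); // hypothetical fitted parameters
    println!("score 2.0 -> p = {:.3}", platt(2.0, a, b));
}
```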
### Explainability

- TreeSHAP — exact Shapley values for tree ensembles in polynomial time (Lundberg & Lee, 2018)
- Permutation importance — model-agnostic feature importance with configurable repeats (Breiman, 2001)
```rust
use scry_learn::prelude::*;

// ensemble_tree_shap's call shape matches the sklearn mapping table
// below; permutation_importance's repeat-count argument is assumed.
let shap_values = ensemble_tree_shap(&rf, &features);
let importance = permutation_importance(&rf, &data, 10);
```
### Text

- CountVectorizer — n-gram term counts, min/max document frequency, sparse CSR output
- TfidfVectorizer — TF-IDF with L1/L2 normalization, sublinear TF, smooth IDF (usage sketched after this list)
- Tokenizer — zero-dependency whitespace/punctuation-aware tokenizer
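A usage sketch for the vectorizers. The builder and `fit_transform` names follow the crate's sklearn-style conventions but are assumptions here:

```rust
use scry_learn::prelude::*;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let docs = ["the quick brown fox", "the lazy dog"];

    // TF-IDF with defaults; fit_transform is assumed to return the
    // sparse CSR matrix mentioned in the feature list above.
    let mut tfidf = TfidfVectorizer::new();
    let _x = tfidf.fit_transform(&docs)?;
    println!("{} documents vectorized", docs.len());
    Ok(())
}
```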
### Anomaly detection

- Isolation Forest — unsupervised anomaly detection via random partitioning (usage sketched below)
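A hypothetical call shape for Isolation Forest, modeled on the crate's other estimators; the method names and score convention are assumptions:

```rust
use scry_learn::prelude::*;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let data = Dataset::from_csv("data.csv")?; // path is illustrative
    let mut forest = IsolationForest::new().n_estimators(100);
    forest.fit(&data)?;
    // Per-sample anomaly scores (method name assumed).
    let scores = forest.predict(&data)?;
    println!("{scores:?}");
    Ok(())
}
```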
## sklearn → scry-learn
The API closely tracks scikit-learn.
| scikit-learn (Python) | scry-learn (Rust) |
|---|---|
| `from sklearn.ensemble import RandomForestClassifier` | `use scry_learn::prelude::*;` |
| `rf = RandomForestClassifier(n_estimators=100)` | `let mut rf = RandomForestClassifier::new().n_estimators(100);` |
| `rf.fit(X_train, y_train)` | `rf.fit(&train)?;` |
| `rf.predict(X_test)` | `rf.predict(&test)?` |
| `cross_val_score(rf, X, y, cv=5)` | `cross_val_score(&rf, &data, 5, accuracy)` |
| `GridSearchCV(rf, param_grid, cv=5)` | `GridSearchCV::new(rf, param_grid, 5, accuracy)` |
| `shap.TreeExplainer(rf).shap_values(X)` | `ensemble_tree_shap(&rf, &features)` |
| `StandardScaler().fit_transform(X)` | `StandardScaler::new().fit_transform(&mut data)?` |
## Benchmarks
Cross-library benchmarks against linfa and smartcore. The harness enforces:
- Real UCI datasets only — no synthetic data with RNG bias
- Counting allocator — actual heap bytes, not RSS estimates (pattern sketched after this list)
- Single-thread execution, asserted programmatically (not assumed via env var)
- Accuracy parity gates — timing only reported when all libraries converge within ε=3%
- Identical preprocessing across libraries
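The harness's counting allocator isn't reproduced here, but the standard Rust pattern it describes looks like this: wrap the system allocator behind `#[global_allocator]` and count bytes and calls atomically. A minimal sketch (bench code typically lives outside the `#![deny(unsafe_code)]` library crate, since `GlobalAlloc` is unsafe):

```rust
use std::alloc::{GlobalAlloc, Layout, System};
use std::sync::atomic::{AtomicUsize, Ordering};

// Every allocation goes through this wrapper, so heap bytes are
// measured exactly rather than estimated from RSS.
struct CountingAlloc;

static BYTES: AtomicUsize = AtomicUsize::new(0);
static CALLS: AtomicUsize = AtomicUsize::new(0);

unsafe impl GlobalAlloc for CountingAlloc {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        BYTES.fetch_add(layout.size(), Ordering::Relaxed);
        CALLS.fetch_add(1, Ordering::Relaxed);
        System.alloc(layout)
    }
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        System.dealloc(ptr, layout)
    }
}

#[global_allocator]
static ALLOC: CountingAlloc = CountingAlloc;

fn main() {
    let before = BYTES.load(Ordering::Relaxed);
    let v: Vec<u64> = (0..1_000).collect();
    println!(
        "{} heap bytes across {} allocations (vec len {})",
        BYTES.load(Ordering::Relaxed) - before,
        CALLS.load(Ordering::Relaxed),
        v.len()
    );
}
```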
Run them yourself:

```sh
cargo bench

# Extended scaling curves (500 / 2K / 10K samples)
cargo bench --bench scaling   # bench target name is illustrative
```
## Install
```toml
[dependencies]
scry-learn = "0.1"
```
### Optional features
| Feature | What it enables |
|---|---|
| `csv` | `Dataset::from_csv()` file loading |
| `serde` | Serialize / deserialize models |
| `polars` | Polars DataFrame interop |
| `mmap` | Memory-mapped dataset loading |
| `experimental` | Kernel SVM (RBF, polynomial kernels) |
```toml
scry-learn = { version = "0.1", features = ["csv", "serde"] }
```
## Examples
```sh
# 5-fold stratified CV across 8 models on 4 UCI datasets
cargo run --release --example model_comparison   # example name is illustrative

# Head-to-head comparison vs linfa and smartcore
cargo run --release --example crosslib           # example name is illustrative
```
## Test suite
843 tests across 24 test files cover correctness, convergence, numerical stability, and cross-library parity.
| Test suite | What it validates |
|---|---|
| `correctness` | Accuracy verified against sklearn reference outputs |
| `convergence` | Monotonic improvement and `max_iter` stability |
| `numerical_stability` | NaN/Inf handling, gradient norm tracking |
| `mathematical_invariants` | SHAP additivity, Σφᵢ = pred − E[f(x)] (sketched below) |
| `golden_regression_test` | Deterministic snapshot tests |
| `statistical_robustness` | Bootstrap confidence interval validity |
| `edge_cases` | Empty datasets, single samples, NaN/Inf inputs |
| `production_bench` | Heap memory, allocation counts, scaling curves |
| `memory_crosslib` | Heap usage comparison across libraries |
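For the `mathematical_invariants` row: SHAP additivity means the base value plus the per-feature attributions must reconstruct the model's prediction. A minimal, self-contained illustration of the check (not the suite's actual test code):

```rust
/// Additivity: base + Σφᵢ must equal the prediction within tolerance.
fn shap_additivity_holds(base: f64, phi: &[f64], pred: f64) -> bool {
    (base + phi.iter().sum::<f64>() - pred).abs() <= 1e-6
}

fn main() {
    // Toy numbers: expected value 0.30, two attributions, prediction 0.40.
    assert!(shap_additivity_holds(0.30, &[0.15, -0.05], 0.40));
    println!("additivity holds");
}
```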
## Contributing
Issues and PRs welcome. Please open an issue before large changes.
## License
MIT OR Apache-2.0