tensorlogic-train
Training scaffolds for Tensorlogic: loss composition, optimizers, schedulers, and callbacks.
Overview
tensorlogic-train provides comprehensive training infrastructure for Tensorlogic models, combining standard ML training components with logic-specific loss functions for constraint satisfaction and rule adherence.
Features
Loss Functions (15 types)
- Standard Losses: Cross-entropy, MSE, BCE with logits
- Robust Losses: Focal (class imbalance), Huber (outliers)
- Segmentation: Dice, Tversky (IoU-based losses)
- Metric Learning: Contrastive, Triplet (embedding learning)
- Classification: Hinge (SVM-style max-margin)
- Distribution: KL Divergence (distribution matching)
- Advanced: Poly Loss (polynomial expansion of CE for better generalization)
- Logical Losses: Rule satisfaction, constraint violation penalties
- Multi-objective: Weighted combination of supervised + logical losses
- Gradient Computation: All losses support automatic gradient computation
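For intuition on the robust losses above, here is a minimal standalone sketch of binary focal loss (the class-imbalance loss listed above); the function name and signature are illustrative, not the crate's API:

```rust
/// Minimal binary focal loss: FL(p_t) = -(1 - p_t)^gamma * ln(p_t),
/// where p_t is the predicted probability of the true class.
/// Illustrative sketch only; not the tensorlogic-train API.
fn focal_loss(probs: &[f64], targets: &[f64], gamma: f64) -> f64 {
    let eps = 1e-12; // avoid ln(0)
    let n = probs.len() as f64;
    probs
        .iter()
        .zip(targets)
        .map(|(&p, &t)| {
            // Probability assigned to the true class
            let p_t = if t > 0.5 { p } else { 1.0 - p };
            let p_t = p_t.clamp(eps, 1.0 - eps);
            // Down-weight easy examples by (1 - p_t)^gamma
            -(1.0 - p_t).powf(gamma) * p_t.ln()
        })
        .sum::<f64>()
        / n
}

fn main() {
    // With gamma = 0, focal loss reduces to plain cross-entropy
    let ce = focal_loss(&[0.9, 0.2], &[1.0, 0.0], 0.0);
    let fl = focal_loss(&[0.9, 0.2], &[1.0, 0.0], 2.0);
    println!("CE-equivalent: {ce:.4}, focal (gamma=2): {fl:.4}");
    assert!(fl < ce); // confident examples are down-weighted
}
```

With `gamma = 0` the modulating factor vanishes and the loss reduces to cross-entropy, which is a useful sanity check when tuning `gamma`.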
Optimizers (18 types)
- SGD: Momentum support, gradient clipping (value and L2 norm)
- Adam: First/second moment estimation, bias correction
- AdamW: Decoupled weight decay for better regularization
- AdamP: Adam with projection for better generalization
- RMSprop: Adaptive learning rates with moving average
- Adagrad: Accumulating gradient normalization
- NAdam: Nesterov-accelerated Adam
- LAMB: Layer-wise adaptive moments (large-batch training)
- AdaMax: Adam variant with infinity norm (robust to large gradients)
- Lookahead: Slow/fast weights for improved convergence
- AdaBelief (NeurIPS 2020): Adapts stepsizes by gradient belief
- RAdam (ICLR 2020): Rectified Adam with variance warmup
- LARS: Layer-wise adaptive rate scaling for large batch training
- SAM (ICLR 2021): Sharpness aware minimization for better generalization
- Lion: Modern memory-efficient optimizer with sign-based updates (EvoLved Sign Momentum)
- Prodigy (2024): Auto-tuning learning rate that largely eliminates manual LR tuning
- ScheduleFreeAdamW (2024): Eliminates LR scheduling via parameter averaging (Defazio et al.)
- Sophia: Second-order optimizer with Hessian diagonal estimates
- Gradient Clipping: By value (element-wise) or by L2 norm (global)
- Gradient Centralization: Drop-in optimizer wrapper (LayerWise, Global, PerRow, PerColumn)
- State Management: Save/load optimizer state for checkpointing
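To make the Adam-family entries above concrete, a single AdamW update for one scalar parameter can be sketched as follows; this is the standard textbook update with decoupled weight decay, not the crate's internal code:

```rust
/// One AdamW step for a single scalar parameter (illustrative, not the crate API).
/// m, v are first/second moment estimates; t is the 1-based step count.
fn adamw_step(
    w: f64, g: f64, m: &mut f64, v: &mut f64, t: u32,
    lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64,
) -> f64 {
    *m = beta1 * *m + (1.0 - beta1) * g;     // first moment estimate
    *v = beta2 * *v + (1.0 - beta2) * g * g; // second moment estimate
    let m_hat = *m / (1.0 - beta1.powi(t as i32)); // bias correction
    let v_hat = *v / (1.0 - beta2.powi(t as i32));
    // Decoupled weight decay: applied to w directly, not folded into the gradient
    w - lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * w)
}

fn main() {
    let (mut m, mut v) = (0.0, 0.0);
    let w = adamw_step(1.0, 0.5, &mut m, &mut v, 1, 0.01, 0.9, 0.999, 1e-8, 0.01);
    println!("updated weight: {w:.6}");
}
```

The decoupling is the key difference from plain Adam with L2 regularization: the decay term never passes through the adaptive denominator.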
Learning Rate Schedulers (12 types)
- StepLR: Step decay every N epochs
- ExponentialLR: Exponential decay per epoch
- CosineAnnealingLR: Cosine annealing with warmup
- WarmupScheduler: Linear learning rate warmup
- OneCycleLR: Super-convergence with single cycle
- PolynomialDecayLR: Polynomial learning rate decay
- CyclicLR: Triangular/exponential cyclic schedules
- WarmupCosineLR: Warmup + cosine annealing
- NoamScheduler: The Transformer schedule from "Attention Is All You Need"
- MultiStepLR: Decay at specific milestone epochs
- ReduceLROnPlateau: Adaptive reduction based on validation metrics
- SGDR: Stochastic Gradient Descent with Warm Restarts
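The warmup-plus-cosine pattern listed above (WarmupCosineLR) can be expressed as a small pure function; the signature here is illustrative, not the crate's scheduler trait:

```rust
use std::f64::consts::PI;

/// Linear warmup followed by cosine annealing to zero
/// (illustrative of the WarmupCosineLR schedule; not the crate's exact API).
fn warmup_cosine_lr(step: usize, warmup_steps: usize, total_steps: usize, base_lr: f64) -> f64 {
    if step < warmup_steps {
        // Linear ramp from ~0 up to base_lr
        base_lr * (step as f64 + 1.0) / warmup_steps as f64
    } else {
        // Cosine decay over the remaining steps
        let progress = (step - warmup_steps) as f64 / (total_steps - warmup_steps) as f64;
        base_lr * 0.5 * (1.0 + (PI * progress).cos())
    }
}

fn main() {
    for step in [0, 99, 100, 550, 999] {
        println!("step {step}: lr = {:.5}", warmup_cosine_lr(step, 100, 1000, 0.1));
    }
}
```

At the warmup boundary the two pieces meet at `base_lr`, and the cosine half-cycle decays smoothly toward zero by the final step.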
Batch Management
- BatchIterator: Configurable batch iteration with shuffling
- DataShuffler: Deterministic shuffling with seed control
- Flexible Configuration: Drop last, custom batch sizes
Training Loop
- Trainer: Complete training orchestration
- Epoch/Batch Iteration: Automated iteration with state tracking
- Validation: Built-in validation loop with metrics
- History Tracking: Loss and metrics history across epochs
Callbacks (14+ types)
- Training Events: on_train/epoch/batch/validation hooks
- EarlyStoppingCallback: Stop training when validation plateaus
- CheckpointCallback: Save model checkpoints (best/periodic) with optional gzip compression
- ReduceLrOnPlateauCallback: Adaptive learning rate reduction
- LearningRateFinder: Find optimal learning rate automatically
- GradientMonitor: Track gradient flow and detect issues
- HistogramCallback: Monitor weight distributions
- ProfilingCallback: Track training performance and throughput
- ModelEMACallback: Exponential moving average for stable predictions
- GradientAccumulationCallback: Simulate large batches with limited memory
- SWACallback: Stochastic Weight Averaging for better generalization
- MemoryProfilerCallback: Track memory usage during training
- Custom Callbacks: Easy-to-implement callback trait
Metrics (19+ types)
- Accuracy: Classification accuracy with argmax
- Precision/Recall: Per-class and macro-averaged
- F1 Score: Harmonic mean of precision/recall
- ConfusionMatrix: Full confusion matrix with per-class analysis
- ROC/AUC: ROC curve computation and AUC calculation
- PerClassMetrics: Comprehensive per-class reporting with pretty printing
- MetricTracker: Multi-metric tracking with history
- TopKAccuracy: Top-k classification accuracy
- NDCG: Normalized Discounted Cumulative Gain (ranking)
- BalancedAccuracy, CohensKappa, MatthewsCorrelationCoefficient
- Computer Vision: IoU, MeanIoU, DiceCoefficient, MeanAveragePrecision
- Calibration: ExpectedCalibrationError, MaximumCalibrationError
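As a reference for how per-class precision/recall fall out of a confusion matrix, here is a standalone helper (not the crate's `ConfusionMatrix` type; the row/column convention is an assumption for illustration):

```rust
/// Per-class precision/recall from a confusion matrix cm[actual][predicted]
/// (illustrative helper; not the crate's ConfusionMatrix API).
fn precision_recall(cm: &[Vec<usize>], class: usize) -> (f64, f64) {
    let tp = cm[class][class] as f64;
    // Everything predicted as `class` (column sum)
    let predicted: f64 = cm.iter().map(|row| row[class] as f64).sum();
    // Everything actually `class` (row sum)
    let actual: f64 = cm[class].iter().map(|&x| x as f64).sum();
    (tp / predicted.max(1.0), tp / actual.max(1.0))
}

fn main() {
    // rows = actual class, cols = predicted class
    let cm = vec![vec![5, 1], vec![2, 4]];
    let (p, r) = precision_recall(&cm, 0);
    let f1 = 2.0 * p * r / (p + r);
    println!("class 0: precision={p:.3} recall={r:.3} f1={f1:.3}");
}
```

Macro-averaging simply repeats this per class and averages the results, which is how the per-class and macro-averaged variants listed above relate.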
Model Interface
- Model Trait: Flexible interface for trainable models
- AutodiffModel: Integration point for automatic differentiation
- DynamicModel: Support for variable-sized inputs
- LinearModel: Reference implementation demonstrating the interface
Regularization (9 types)
- L1 Regularization: Lasso with sparsity-inducing penalties
- L2 Regularization: Ridge for weight decay
- Elastic Net: Combined L1+L2 regularization
- Composite: Combine multiple regularization strategies
- Spectral Normalization: GAN training stability
- MaxNorm Constraint: Gradient stability
- Orthogonal Regularization: W^T * W approximates identity
- Group Lasso: Group-wise sparsity
- Full Gradient Support: All regularizers compute gradients
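For reference, the Elastic Net penalty and its gradient (combining the L1 and L2 entries above) reduce to a few lines; these free functions are illustrative only, not the crate's `Regularizer` trait:

```rust
/// Elastic Net penalty: l1 * sum|w| + l2 * sum(w^2)
/// (illustrative of the L1/L2/ElasticNet regularizers above).
fn elastic_net_penalty(weights: &[f64], l1: f64, l2: f64) -> f64 {
    let l1_term: f64 = weights.iter().map(|w| w.abs()).sum();
    let l2_term: f64 = weights.iter().map(|w| w * w).sum();
    l1 * l1_term + l2 * l2_term
}

/// Gradient of the penalty w.r.t. each weight: l1 * sign(w) + 2 * l2 * w
fn elastic_net_grad(weights: &[f64], l1: f64, l2: f64) -> Vec<f64> {
    weights.iter().map(|&w| l1 * w.signum() + 2.0 * l2 * w).collect()
}

fn main() {
    let w = [1.0, -2.0, 0.5];
    println!("penalty = {}", elastic_net_penalty(&w, 0.01, 0.001));
    println!("grad = {:?}", elastic_net_grad(&w, 0.01, 0.001));
}
```

Setting `l1 = 0` recovers pure Ridge (L2) and `l2 = 0` recovers pure Lasso (L1).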
Data Augmentation (9+ types)
- Noise Augmentation: Gaussian noise with Box-Muller transform
- Scale Augmentation: Random scaling within configurable ranges
- Rotation Augmentation: Placeholder for future image rotation
- Mixup: Zhang et al. (ICLR 2018) for improved generalization
- CutMix: CutMix augmentation technique
- RandomErasing: Randomly erase rectangular regions (AAAI 2020)
- CutOut: Fixed-size random square erasing
- Composite Pipeline: Chain multiple augmentations
- SciRS2 RNG: Uses SciRS2 for random number generation
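The Mixup entry above boils down to a convex combination of two examples and their labels; the sketch below shows the math with a fixed mixing coefficient (in practice lambda is sampled from Beta(alpha, alpha)), and is not the crate's `Mixup` type:

```rust
/// Mixup (Zhang et al., ICLR 2018): convex combination of two examples
/// and their labels. Standalone sketch; lambda would normally be sampled
/// from Beta(alpha, alpha) rather than passed in.
fn mixup(x1: &[f64], x2: &[f64], y1: &[f64], y2: &[f64], lambda: f64) -> (Vec<f64>, Vec<f64>) {
    let mix = |a: &[f64], b: &[f64]| {
        a.iter()
            .zip(b)
            .map(|(&ai, &bi)| lambda * ai + (1.0 - lambda) * bi)
            .collect::<Vec<f64>>()
    };
    (mix(x1, x2), mix(y1, y2))
}

fn main() {
    let (x, y) = mixup(&[1.0, 0.0], &[0.0, 1.0], &[1.0, 0.0], &[0.0, 1.0], 0.7);
    println!("mixed input {x:?}, mixed target {y:?}");
}
```

Mixing the labels as well as the inputs is what distinguishes Mixup from simple input interpolation and gives its regularizing effect.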
Regularization (Advanced)
- DropPath / Stochastic Depth: Randomly drops entire residual paths (ECCV 2016)
- Linear and Exponential schedulers for depth-based drop probability
- Widely used in Vision Transformers (ViT, DeiT, Swin)
- DropBlock: Structured dropout for CNNs (NeurIPS 2018)
- Drops contiguous blocks instead of individual neurons
- Configurable block size with linear scheduler
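The DropPath behavior described above amounts to zeroing a residual branch with probability `p` and rescaling it when kept; this standalone sketch takes the random draw as a parameter for determinism and is not the crate's implementation:

```rust
/// DropPath / stochastic depth on a residual branch (illustrative sketch).
/// During training the branch output is zeroed with probability drop_prob
/// and scaled by 1/(1 - drop_prob) when kept, preserving its expectation.
fn drop_path(branch: &[f64], drop_prob: f64, training: bool, coin: f64) -> Vec<f64> {
    if !training || drop_prob == 0.0 {
        return branch.to_vec();
    }
    if coin < drop_prob {
        // Whole branch dropped: only the identity shortcut survives
        vec![0.0; branch.len()]
    } else {
        // Keep and rescale so E[output] matches the no-drop case
        let scale = 1.0 / (1.0 - drop_prob);
        branch.iter().map(|x| x * scale).collect()
    }
}

fn main() {
    let branch = [1.0, 2.0];
    // coin is a uniform sample in [0, 1); passed in here for determinism
    println!("{:?}", drop_path(&branch, 0.2, true, 0.9)); // kept and rescaled
    println!("{:?}", drop_path(&branch, 0.2, true, 0.1)); // dropped to zeros
}
```

Note the whole branch is kept or dropped as a unit, unlike element-wise dropout; DropBlock applies the same idea to contiguous spatial blocks.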
Logging and Monitoring
- Console Logger: Stdout logging with timestamps
- File Logger: Persistent file logging with append/truncate modes
- TensorBoard Logger: Real tfevents format with CRC32, scalars, histograms, text
- CSV Logger: Machine-readable CSV for pandas/spreadsheet analysis
- JSONL Logger: JSON Lines format for programmatic processing
- Metrics Logger: Aggregates and logs to multiple backends
- Structured Logging: Optional tracing/tracing-subscriber integration (feature flag)
Memory Management
- MemoryProfilerCallback: Track memory usage with reports during training
- GradientCheckpointConfig: Strategies for memory-efficient training
- MemoryBudgetManager: Allocation tracking and budget enforcement
- MemoryEfficientTraining: Optimal batch size and model memory estimation
Data Loading and Preprocessing
- Dataset struct: Unified data container with features/targets
- Train/val/test splits: Configurable split ratios with validation
- CSV Loader: Configurable CSV data loading with column selection
- DataPreprocessor: Standardization, normalization, min-max scaling
- LabelEncoder: String to numeric label conversion
- OneHotEncoder: Categorical to binary encoding
Advanced Machine Learning
- Curriculum Learning: 5 strategies (Linear, Exponential, SelfPaced, Competence, Task)
- Transfer Learning: LayerFreezing, ProgressiveUnfreezing, DiscriminativeFineTuning
- Hyperparameter Optimization: Grid search, Random search, Bayesian Optimization (GP-based)
- Cross-Validation: KFold, StratifiedKFold, TimeSeriesSplit, LeaveOneOut
- Model Ensembling: Voting, Averaging, Stacking, Bagging
- Knowledge Distillation: Standard, Feature, Attention transfer
- Label Smoothing: Label smoothing and Mixup regularization
- Multi-task Learning: Fixed weights, DTP, PCGrad strategies
- Few-Shot Learning: Prototypical networks, Matching networks, N-way K-shot sampling
- Meta-Learning: MAML (Model-Agnostic Meta-Learning) and Reptile
- Model Pruning: Magnitude, Gradient, Structured, Global pruning with iterative schedules
- Model Quantization: INT8/INT4/INT2, Post-Training Quantization, Quantization-Aware Training
- Mixed Precision Training: FP16/BF16 support, loss scaling, master weights
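As a concrete example of one technique from the list above, label smoothing replaces a one-hot target with a slightly softened distribution; the helper below is an illustrative sketch, not the crate's API:

```rust
/// Label smoothing: replace a one-hot target with
/// (1 - eps) * one_hot + eps / num_classes (illustrative sketch).
fn smooth_labels(class: usize, num_classes: usize, eps: f64) -> Vec<f64> {
    let uniform = eps / num_classes as f64;
    (0..num_classes)
        .map(|c| if c == class { 1.0 - eps + uniform } else { uniform })
        .collect()
}

fn main() {
    // Class 1 of 4 with eps = 0.1: the true class keeps most of the mass,
    // the rest is spread uniformly, and the distribution still sums to 1.
    println!("{:?}", smooth_labels(1, 4, 0.1));
}
```

Training against these softened targets penalizes over-confident logits, which is why label smoothing is grouped with the regularization techniques.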
Sampling Strategies (7 techniques)
- HardNegativeMiner: TopK, threshold, and focal strategies
- ImportanceSampler: With/without replacement
- FocalSampler: Emphasize hard examples
- ClassBalancedSampler: Handle class imbalance
- CurriculumSampler: Progressive difficulty
- OnlineHardExampleMiner: Dynamic batch selection
- BatchReweighter: Uniform, inverse loss, focal, gradient norm
Model Utilities
- ParameterStats / ModelSummary: Parameter analysis and model introspection
- GradientStats: Gradient monitoring and analysis
- TimeEstimator: Training time prediction
- LrRangeTestAnalyzer: Optimal LR finding analysis
- compare_models: Model comparison utilities
Installation
Add to your Cargo.toml:

```toml
[dependencies]
tensorlogic-train = { path = "../tensorlogic-train" }
```
Quick Start
The snippet below is an illustrative sketch of the training workflow; exact constructor signatures and field names may differ from the current API.

```rust
use tensorlogic_train::*;
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

// Create loss function (type name illustrative)
let loss = Box::new(MseLoss::new());

// Create optimizer
let optimizer_config = OptimizerConfig::default();
let optimizer = Box::new(Sgd::new(optimizer_config));

// Create trainer
let config = TrainerConfig::default();
let mut trainer = Trainer::new(config, loss, optimizer);

// Add callbacks
let mut callbacks = CallbackList::new();
callbacks.add(Box::new(EarlyStoppingCallback::new(5, 1e-4)));
trainer = trainer.with_callbacks(callbacks);

// Add metrics
let mut metrics = MetricTracker::new();
metrics.add(Box::new(Accuracy::new()));
trainer = trainer.with_metrics(metrics);

// Prepare data
let train_data = Array2::<f64>::zeros((100, 10));
let train_targets = Array2::<f64>::zeros((100, 1));
let val_data = Array2::<f64>::zeros((20, 10));
let val_targets = Array2::<f64>::zeros((20, 1));

// Train model
let mut parameters = HashMap::new();
parameters.insert("weights".to_string(), Array2::<f64>::zeros((10, 1)));
let history = trainer
    .train(&mut parameters, &train_data, &train_targets, &val_data, &val_targets)
    .unwrap();

// Access training history
println!("Train loss history: {:?}", history.train_loss);
println!("Validation loss history: {:?}", history.val_loss);
if let Some(best) = history.best_val_loss {
    println!("Best validation loss: {best}");
}
```
Logical Loss Functions
Combine supervised learning with logical constraints:
```rust
use tensorlogic_train::{LogicalLoss, LossConfig};

// Configure loss weights (field names illustrative)
let config = LossConfig {
    supervised_weight: 1.0,
    rule_weight: 0.5,
    constraint_weight: 0.5,
};

// Create logical loss
let logical_loss = LogicalLoss::new(config);

// Compute total loss (supervised + logical terms; signature illustrative)
let total_loss = logical_loss.compute_total(&predictions, &targets, &rule_scores)?;
```
Early Stopping
Stop training automatically when validation stops improving:
```rust
use tensorlogic_train::{CallbackList, EarlyStoppingCallback};

let mut callbacks = CallbackList::new();
// patience = 5 epochs, min_delta = 1e-4 (constructor shape illustrative)
callbacks.add(Box::new(EarlyStoppingCallback::new(5, 1e-4)));
trainer = trainer.with_callbacks(callbacks);

// Training will stop automatically if validation doesn't improve for 5 epochs
```
Checkpointing
Save model checkpoints during training:
```rust
use tensorlogic_train::{CallbackList, CheckpointCallback};
use std::path::PathBuf;

let mut callbacks = CallbackList::new();
// Save checkpoints under checkpoints/ (constructor shape illustrative)
callbacks.add(Box::new(CheckpointCallback::new(PathBuf::from("checkpoints"))));
trainer = trainer.with_callbacks(callbacks);
```
Learning Rate Scheduling
Adjust learning rate during training:
```rust
use tensorlogic_train::StepLR;

// Halve the learning rate every 10 epochs (values and signature illustrative)
let scheduler = Box::new(StepLR::new(10, 0.5));
trainer = trainer.with_scheduler(scheduler);
```
Gradient Clipping by Norm
Use L2 norm clipping for stable training of deep networks:
```rust
use tensorlogic_train::{OptimizerConfig, Sgd};

// Enable clipping at max L2 norm 5.0 (configuration shape illustrative)
let optimizer = Box::new(Sgd::new(OptimizerConfig::default()).with_grad_clip_norm(5.0));

// Global L2 norm is computed across all parameters:
//   norm = sqrt(sum(g_i^2 for all gradients g_i))
// If norm > 5.0, all gradients are scaled by (5.0 / norm)
```
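The clipping rule described in the comments above can be demonstrated with a small standalone function over plain vectors (this is the math, not the crate's internal code):

```rust
/// Clip a set of gradient tensors by global L2 norm (standalone sketch
/// of the rule described above; not the crate's internal implementation).
/// Returns the pre-clip norm.
fn clip_by_global_norm(grads: &mut [Vec<f64>], max_norm: f64) -> f64 {
    // Global norm over every element of every tensor
    let norm: f64 = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f64>()
        .sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            for x in g.iter_mut() {
                *x *= scale;
            }
        }
    }
    norm
}

fn main() {
    let mut grads = vec![vec![3.0, 4.0], vec![12.0]];
    // global norm = sqrt(9 + 16 + 144) = 13, so everything is scaled by 5/13
    let norm = clip_by_global_norm(&mut grads, 5.0);
    println!("pre-clip norm = {norm}, clipped = {grads:?}");
}
```

Because all tensors share one scale factor, the relative direction of the update is preserved; only its magnitude is capped.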
Enhanced Metrics
Confusion Matrix
```rust
use tensorlogic_train::ConfusionMatrix;

// predictions/targets are class indices (signature illustrative)
let cm = ConfusionMatrix::compute(&predictions, &targets, num_classes)?;

// Pretty print the confusion matrix
println!("{}", cm);

// Get per-class metrics
let precision = cm.precision_per_class();
let recall = cm.recall_per_class();
let f1 = cm.f1_per_class();

// Get overall accuracy
println!("Accuracy: {:.3}", cm.accuracy());
```
ROC Curve and AUC
```rust
use tensorlogic_train::RocCurve;

// Binary classification example: scores and 0/1 labels (signature illustrative)
let predictions = vec![0.9, 0.8, 0.4, 0.3, 0.1];
let targets = vec![1.0, 1.0, 1.0, 0.0, 0.0];
let roc = RocCurve::compute(&predictions, &targets)?;

// Compute AUC
println!("AUC: {:.3}", roc.auc());
```
Regularization
Prevent overfitting with L1, L2, or Elastic Net regularization:
```rust
use tensorlogic_train::L2Regularization;
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

// Create L2 regularization (weight decay), lambda = 0.01 (type name illustrative)
let regularizer = L2Regularization::new(0.01);

// Compute regularization penalty and gradients
let mut parameters = HashMap::new();
parameters.insert("weights".to_string(), Array2::<f64>::ones((10, 1)));
let penalty = regularizer.compute_penalty(&parameters)?;
let gradients = regularizer.compute_gradient(&parameters)?;

// Add penalty to loss and gradients to parameter updates
total_loss += penalty;
```
Data Augmentation
Apply on-the-fly data augmentation during training:
```rust
use tensorlogic_train::{Mixup, NoiseAugmentation};
use scirs2_core::ndarray::Array2;

// Gaussian noise augmentation: mean = 0.0, std = 0.1 (type names illustrative)
let noise_aug = NoiseAugmentation::new(0.0, 0.1);
let augmented = noise_aug.augment(&data)?;

// Mixup augmentation (Zhang et al., ICLR 2018), alpha = 1.0 (uniform mixing)
let mixup = Mixup::new(1.0);
let (mixed_data, mixed_targets) = mixup.mixup(&data, &targets)?;
```
Logging and Monitoring
Track training progress with multiple logging backends:
```rust
use tensorlogic_train::{ConsoleLogger, FileLogger, MetricsLogger};
use std::path::PathBuf;

// Console logging with timestamps (with_timestamp = true; signatures illustrative)
let console = ConsoleLogger::new(true);
console.log_epoch(epoch, &epoch_metrics)?;

// File logging
let file_logger = FileLogger::new(PathBuf::from("training.log"))?;
file_logger.log_batch(epoch, batch, &batch_metrics)?;

// Aggregate metrics across backends
let mut metrics_logger = MetricsLogger::new();
metrics_logger.add_backend(Box::new(console));
metrics_logger.add_backend(Box::new(file_logger));

// Log to all backends
metrics_logger.log_metric("loss", 0.42, step)?;
metrics_logger.log_epoch(epoch, &epoch_metrics)?;
```
Architecture
Module Structure
tensorlogic-train/
├── src/
│ ├── lib.rs # Public API exports
│ ├── error.rs # Error types
│ ├── loss.rs # 15 loss functions
│ ├── optimizer.rs # Re-exports all optimizers
│ ├── optimizers/ # 18 optimizer implementations
│ │ ├── sgd.rs, adam.rs, adamw.rs, adamp.rs
│ │ ├── rmsprop.rs, adagrad.rs, nadam.rs
│ │ ├── lamb.rs, adamax.rs, lookahead.rs
│ │ ├── adabelief.rs, radam.rs, lars.rs
│ │ ├── sam.rs, lion.rs, prodigy.rs
│ │ ├── schedulefree.rs, sophia.rs
│ │ └── common.rs
│ ├── scheduler.rs # 12 LR schedulers
│ ├── batch.rs # Batch management
│ ├── trainer.rs # Main training loop
│ ├── callbacks/ # 14+ training callbacks
│ │ ├── core.rs, advanced.rs, checkpoint.rs
│ │ ├── early_stopping.rs, gradient.rs
│ │ ├── histogram.rs, lr_finder.rs, profiling.rs
│ │ └── mod.rs
│ ├── metrics/ # 19+ evaluation metrics (7 modules)
│ │ ├── basic.rs, advanced.rs, ranking.rs
│ │ ├── vision.rs, calibration.rs, tracker.rs
│ │ └── mod.rs
│ ├── model.rs # Model trait interface
│ ├── regularization.rs # 9 regularization techniques
│ ├── augmentation.rs # 9 data augmentation types
│ ├── stochastic_depth.rs # DropPath / Stochastic Depth
│ ├── dropblock.rs # DropBlock structured dropout
│ ├── logging.rs # 6 logging backends
│ ├── memory.rs # Memory management
│ ├── data.rs # Data loading and preprocessing
│ ├── curriculum.rs # 5 curriculum learning strategies
│ ├── transfer.rs # Transfer learning utilities
│ ├── hyperparameter.rs # Grid, Random, Bayesian search
│ ├── crossval.rs # 4 cross-validation strategies
│ ├── ensemble.rs # 4 ensembling methods
│ ├── distillation.rs # Knowledge distillation
│ ├── label_smoothing.rs # Label smoothing and Mixup
│ ├── multitask.rs # Multi-task learning
│ ├── pruning.rs # Model pruning
│ ├── sampling.rs # 7 advanced sampling strategies
│ ├── quantization.rs # INT8/INT4/INT2 quantization
│ ├── mixed_precision.rs # FP16/BF16 mixed precision
│ ├── few_shot.rs # Prototypical/Matching networks
│ ├── meta_learning.rs # MAML and Reptile
│ ├── gradient_centralization.rs # Gradient centralization
│ ├── structured_logging.rs # tracing integration (feature-gated)
│ └── utils.rs # Model introspection utilities
Key Traits
- Model: Forward/backward passes and parameter management
- AutodiffModel: Automatic differentiation integration (trait extension)
- DynamicModel: Variable-sized input support
- Loss: Compute loss and gradients
- Optimizer: Update parameters with gradients
- LrScheduler: Adjust learning rate
- Callback: Hook into training events
- Metric: Evaluate model performance
- Regularizer: Compute regularization penalties and gradients
- DataAugmenter: Apply data transformations
- LoggingBackend: Log training metrics and events
Integration with SciRS2
This crate strictly follows the SciRS2 integration policy:
```rust
// Correct: Use SciRS2 types
use scirs2_core::ndarray::Array2;

// Wrong: Never use these directly
// use ndarray::Array2;   // Never!
// use rand::thread_rng;  // Never!
```
All tensor operations use scirs2_core::ndarray, ready for seamless integration with scirs2-autograd for automatic differentiation.
Test Coverage
All modules have comprehensive unit tests:
| Module | Tests | Coverage |
|---|---|---|
| loss.rs | 15 | All 15 loss functions |
| optimizer.rs / optimizers/ | 79 | All 18 optimizers, including ScheduleFreeAdamW, Prodigy, Sophia |
| scheduler.rs | 13 | All 12 schedulers |
| batch.rs | 5 | Batch iteration and sampling |
| trainer.rs | 3 | Training loop |
| callbacks/ | 28 | All 14+ callbacks |
| metrics/ | 34 | All 19+ metrics (refactored into 7 modules) |
| model.rs | 6 | Model interface and implementations |
| regularization.rs | 16 | L1, L2, ElasticNet, Composite, Spectral Norm, MaxNorm, Orthogonal, Group Lasso |
| augmentation.rs | 25 | Noise, Scale, Rotation, Mixup, CutMix, RandomErasing, CutOut |
| stochastic_depth.rs | 14 | DropPath, Linear/Exponential schedulers |
| dropblock.rs | 12 | DropBlock structured dropout |
| logging.rs | 15 | Console, File, TensorBoard, CSV, JSONL, MetricsLogger |
| memory.rs | 10 | MemoryStats, profiler, budget manager |
| curriculum.rs | 11 | Linear, Exponential, SelfPaced, Competence, Task, Manager |
| transfer.rs | 13 | Freezing, Progressive, Discriminative, FeatureExtractor |
| hyperparameter.rs | 32 | Grid, Random, Bayesian Opt, GP, Acquisition |
| crossval.rs | 12 | KFold, Stratified, TimeSeries, LeaveOneOut |
| ensemble.rs | 22 | Voting, Averaging, Stacking, Bagging |
| distillation.rs | 7 | Standard, Feature, Attention distillation |
| label_smoothing.rs | 8 | Label smoothing, Mixup |
| multitask.rs | 5 | Fixed, DTP, PCGrad |
| data.rs | 12 | Dataset, CSV loader, preprocessor, encoders |
| utils.rs | 11 | Model summary, gradient stats, time estimation |
| quantization.rs | 14 | INT8/4/2, PTQ, QAT, calibration |
| mixed_precision.rs | 14 | FP16/BF16, loss scaling, master weights |
| gradient_centralization.rs | 14 | LayerWise, Global, PerRow, PerColumn |
| structured_logging.rs | 4 | Builder, formats, levels |
| few_shot.rs | 13 | Prototypical, Matching networks, distances |
| meta_learning.rs | 15 | MAML, Reptile, task management |
| pruning.rs | 13 | Magnitude, Gradient, Structured, Global, Iterative |
| sampling.rs | 14 | Hard negative mining, Importance, Focal, Class balanced |
| **Total** | **499** | **100%** |
Run tests with:

```bash
cargo test
```
Guides and Documentation
Comprehensive guides are available in the docs/ directory:
- Loss Function Selection Guide - Choose the right loss for your task
  - Decision trees and comparison tables
  - Detailed explanations of all 15 loss functions
  - Metric learning losses (Contrastive, Triplet)
  - Classification losses (Hinge, KL Divergence, Poly)
  - Best practices and common pitfalls
- Hyperparameter Tuning Guide - Optimize training performance
  - Learning rate tuning (with LR finder)
  - Batch size selection
  - Optimizer comparison and selection
  - Learning rate schedules
  - Regularization strategies
Examples
The crate includes comprehensive examples demonstrating all features:
- 01_basic_training.rs - Simple regression with SGD
- 02_classification_with_metrics.rs - Multi-class classification
- 03_callbacks_and_checkpointing.rs - Advanced callbacks and training management
- 04_logical_loss_training.rs - Constraint-based training
- 05_profiling_and_monitoring.rs - Performance profiling
- 06_curriculum_learning.rs - Progressive difficulty training
- 07_transfer_learning.rs - Fine-tuning strategies
- 08_hyperparameter_optimization.rs - Grid/random search
- 09_cross_validation.rs - Robust model evaluation
- 10_ensemble_learning.rs - Model ensembling
- 11_advanced_integration.rs - Complete workflow integration
- 12_knowledge_distillation.rs - Model compression
- 13_label_smoothing.rs - Regularization techniques
- 14_multitask_learning.rs - Multi-task training
- 15_training_recipes.rs - Complete end-to-end workflows
- 16_structured_logging.rs - Production-grade observability
- 17_few_shot_learning.rs - Learning from minimal examples
- 18_meta_learning.rs - MAML and Reptile algorithms
- 21_bayesian_optimization.rs - Bayesian hyperparameter search
- 22_gradient_centralization.rs - Gradient preprocessing
Run any example with:

```bash
cargo run --example 01_basic_training
```
Benchmarks
Performance benchmarks are available in the benches/ directory; run them with `cargo bench`.
Benchmarks cover:
- Optimizer comparison (SGD, Adam, AdamW)
- Batch size scaling
- Dataset size scaling
- Model size scaling
- Gradient clipping overhead
License
Apache-2.0
Contributing
See CONTRIBUTING.md for guidelines.
Status: Production Ready (v0.1.0-rc.1)
Last Updated: 2026-03-06
Version: 0.1.0-rc.1
Test Coverage: 499/499 tests passing (100%)
Code Quality: Zero warnings, clippy clean
Features: 15 losses, 18 optimizers, 12 schedulers, 14+ callbacks, 9 regularization techniques, 9 augmentations, DropPath, DropBlock, quantization, mixed precision, few-shot learning, meta-learning, Bayesian optimization, gradient centralization
Examples: 20 comprehensive training examples