# torsh-models

Pre-built model architectures and model zoo for ToRSh.

## Overview

This crate provides ready-to-use model architectures for various domains:
- Computer Vision: ResNet, EfficientNet, Vision Transformer, etc.
- Natural Language Processing: BERT, GPT, T5, etc.
- Audio Processing: Wav2Vec2, Whisper, etc.
- Multimodal: CLIP, DALL-E, etc.
- Model Utilities: Weight initialization, model surgery, pruning
## Usage

### Vision Models

```rust
use torsh_models::vision::*;

// ResNet variants
let resnet18 = resnet18()?;
let resnet50 = resnet50()?;
let resnet101 = resnet101()?;

// Custom configuration
let custom_resnet = ResNet::new(ResNetConfig { /* ... */ })?;

// EfficientNet family
let efficientnet_b0 = efficientnet_b0()?;
let efficientnet_b7 = efficientnet_b7()?;

// Vision Transformer
let vit_b_16 = vit_base_patch16_224()?;
let vit_l_32 = vit_large_patch32_384()?;

// Object Detection
let faster_rcnn = fasterrcnn_resnet50_fpn()?;
let mask_rcnn = maskrcnn_resnet50_fpn()?;

// Segmentation
let deeplabv3 = deeplabv3_resnet101()?;
let fcn = fcn_resnet50()?;
```
### NLP Models

```rust
use torsh_models::nlp::*;

// BERT variants
let bert_base = bert_base_uncased()?;
let bert_large = bert_large_cased()?;

// Custom BERT configuration
let custom_bert = BertModel::new(BertConfig { /* ... */ })?;

// GPT models
let gpt2 = gpt2()?;
let gpt2_medium = gpt2_medium()?;

// T5 models
let t5_small = t5_small()?;
let t5_base = t5_base()?;

// For specific tasks
let bert_classifier = BertForSequenceClassification::new(BertConfig { /* ... */ })?;
```
### Audio Models

```rust
use torsh_models::audio::*;

// Wav2Vec2
let wav2vec2 = wav2vec2_base()?;

// Whisper
let whisper_base = whisper_base()?;
let whisper_large = whisper_large()?;

// Audio classification
let audio_classifier = AudioClassifier::new(AudioClassifierConfig { /* ... */ })?;
```
### Multimodal Models

```rust
use torsh_models::multimodal::*;

// CLIP
let clip = clip_vit_base_patch32()?;
let output = clip.forward(&image_input, &text_input)?;

// Flamingo
let flamingo = flamingo_base()?;
```
### Model Utilities

```rust
use torsh_models::utils::*;

// Weight initialization
init_weights(&mut model)?;

// Model surgery
let pruned_model = prune_model(&model, /* sparsity */ 0.5)?;

// Knowledge distillation
let student = distill_model(&teacher, &student_config)?;

// Model conversion
let quantized = quantize_model(&model)?;
let onnx_model = export_onnx(&model, "model.onnx")?;
```
### Transfer Learning

```rust
use torsh_models::transfer::*;

// Fine-tune pre-trained model
let base_model = resnet50()?;
let feature_extractor = remove_head(base_model);
let new_model = add_custom_head(feature_extractor, num_classes)?;

// Freeze base layers
freeze_layers(&mut new_model, &["layer1", "layer2"])?;

// Progressive unfreezing
let scheduler = UnfreezeScheduler::new()
    .unfreeze_at(/* epoch */ 5, "layer4")
    .unfreeze_at(10, "layer3");
```
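Under the hood, progressive unfreezing is just an epoch-to-layer schedule. The following is a minimal, dependency-free sketch of that idea in plain Rust; the type and method names here are illustrative, not the torsh-models API:

```rust
/// Minimal progressive-unfreezing schedule: records at which epoch each
/// named layer becomes trainable (illustrative sketch only).
pub struct UnfreezeSchedule {
    entries: Vec<(usize, String)>, // (epoch, layer name)
}

impl UnfreezeSchedule {
    pub fn new() -> Self {
        Self { entries: Vec::new() }
    }

    /// Builder-style registration, mirroring the `.unfreeze_at(...)` calls above.
    pub fn unfreeze_at(mut self, epoch: usize, layer: &str) -> Self {
        self.entries.push((epoch, layer.to_string()));
        self
    }

    /// Layers that should become trainable exactly at `epoch`.
    pub fn layers_to_unfreeze(&self, epoch: usize) -> Vec<&str> {
        self.entries
            .iter()
            .filter(|(e, _)| *e == epoch)
            .map(|(_, name)| name.as_str())
            .collect()
    }
}
```

A training loop would query `layers_to_unfreeze(epoch)` at the start of each epoch and mark the returned layers as trainable.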
### Model Configuration

```rust
use torsh_models::config::*;

// Load configuration
let mut config = BertConfig::from_pretrained("bert-base-uncased")?;

// Modify configuration
config.hidden_size = 1024;
config.num_attention_heads = 16;

// Create model from config
let model = BertModel::from_config(&config)?;

// Save configuration
config.save("my_bert_config.json")?;
```
### Model Registry

```rust
use torsh_models::registry::*;

// Register custom model
register_model("my_model", my_model_builder);

// List available models
let models = list_models()?;
for name in models {
    println!("{}", name);
}

// Load by name
let model = load_model("resnet18")?;
```
### Benchmarking

```rust
use torsh_models::benchmark::*;

// Benchmark inference speed
let results = benchmark_model(&model, &input, /* iterations */ 100)?;
println!("Mean latency: {:?}", results.mean_latency);
println!("Throughput: {:.1} samples/s", results.throughput);

// Compare models
let comparison = compare_models(&[&model_a, &model_b], &input)?;
```
## Tutorials

### Tutorial 1: Image Classification with Pre-trained ResNet

```rust
use torsh_models::vision::*;
use torsh_models::registry::get_global_registry;
use torsh::Tensor;

// Load a pre-trained ResNet model
let registry = get_global_registry();
let model_handle = registry.get_model_handle("resnet18")?;

// Create the model
let mut model = resnet18()?;
model.eval(); // Set to evaluation mode

// Prepare input (batch of RGB images)
let batch_size = 4;
let input = Tensor::randn(&[batch_size, 3, 224, 224])?;

// Forward pass
let output = model.forward(&input)?;
let predictions = output.softmax(-1)?;

// Get top-5 predictions
let (top5_probs, top5_indices) = predictions.topk(5)?;
println!("Top-5 class indices: {:?}", top5_indices);
for i in 0..batch_size {
    // Per-image post-processing (e.g. mapping class indices to labels) goes here.
}
```
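The top-k step itself is framework-independent: rank class indices by probability and keep the first k. A plain-Rust sketch of what `topk(5)` computes for one row of probabilities:

```rust
/// Return the indices of the `k` largest values, highest first
/// (plain-Rust sketch of what a per-row tensor `topk` computes).
pub fn topk_indices(probs: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    // Sort indices by descending probability.
    idx.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
    idx.truncate(k);
    idx
}
```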
### Tutorial 2: Text Classification with BERT

```rust
use torsh_models::nlp::*;
use torsh::Tensor;

// Create BERT model for sequence classification
let config = BertConfig { num_labels: 2, ..Default::default() };
let mut model = BertForSequenceClassification::new(config)?;
model.eval();

// Prepare tokenized input (token IDs)
let batch_size = 2;
let seq_len = 128;
let input_ids = Tensor::randint(0, 30522, &[batch_size, seq_len])?; // BERT vocab size
let attention_mask = Tensor::ones(&[batch_size, seq_len])?;

// Forward pass
let output = model.forward(&input_ids, Some(&attention_mask))?;
let logits = output.logits;
let predictions = logits.softmax(-1)?;

// Get predicted classes
let predicted_classes = predictions.argmax(-1)?;
println!("Predicted classes: {:?}", predicted_classes);
```
### Tutorial 3: Speech Recognition with Whisper

```rust
use torsh_models::audio::*;
use torsh::Tensor;

// Load Whisper model
let config = WhisperConfig::base();
let mut model = WhisperModel::new(config)?;
model.eval();

// Prepare mel spectrogram input
let batch_size = 1;
let n_mels = 80;
let seq_len = 3000;
let input_features = Tensor::randn(&[batch_size, n_mels, seq_len])?;

// Generate transcription
let decoder_input_ids = Tensor::from_slice(&[50258i64])?; // Start token
let output = model.generate(&input_features, &decoder_input_ids)?;
println!("Generated token IDs: {:?}", output);
```
### Tutorial 4: Vision-Language Understanding with CLIP

```rust
use torsh_models::multimodal::*;
use torsh::Tensor;

// Load CLIP model
let config = ClipConfig::default();
let mut model = ClipModel::new(config)?;
model.eval();

// Prepare inputs
let batch_size = 4;
let image_input = Tensor::randn(&[batch_size, 3, 224, 224])?;
let text_input = Tensor::randint(0, 49408, &[batch_size, 77])?; // Tokenized text

// Get embeddings
let image_features = model.get_image_features(&image_input)?;
let text_features = model.get_text_features(&text_input)?;

// Compute similarity
let similarity = image_features.matmul(&text_features.transpose(0, 1)?)?;
let probs = similarity.softmax(-1)?;
println!("Image-text match probabilities: {:?}", probs);
```
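The similarity step above boils down to dot products between normalized embeddings followed by a softmax over the candidates. A self-contained sketch of that math in plain Rust, with no ToRSh types involved:

```rust
/// Cosine similarity between two embedding vectors.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb)
}

/// Numerically stable softmax, as applied to a row of similarity scores.
pub fn softmax(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```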
### Tutorial 5: Fine-tuning for Custom Dataset

```rust
use torsh_models::vision::*;
use torsh::optim::SGD;
use torsh::nn::functional as F;

// Load pre-trained model and modify for custom task
let mut base_model = resnet18()?;

// Replace classifier head for custom number of classes
let num_custom_classes = 10;
let in_features = 512; // ResNet18 final layer input size
let custom_head = Linear::new(in_features, num_custom_classes)?;
base_model.fc = custom_head;

// Set up optimizer
let mut optimizer = SGD::new(base_model.parameters(), /* lr */ 0.01)?;

// Training loop
for epoch in 0..10 {
    // For each batch: forward pass, F::cross_entropy loss,
    // loss.backward(), optimizer.step(), optimizer.zero_grad().
}
```
### Tutorial 6: Model Quantization and Optimization

```rust
use torsh_models::vision::*;
use torsh::quantization::*;
use torsh::Tensor;
use std::time::Instant;

// Load pre-trained model
let mut model = resnet50()?;
model.eval();

// Prepare calibration data
let calibration_data: Vec<Tensor> = (0..10)
    .map(|_| Tensor::randn(&[1, 3, 224, 224]))
    .collect::<Result<_, _>>()?;

// Configure quantization
let quant_config = QuantizationConfig { /* ... */ };

// Quantize the model
let mut quantizer = Quantizer::new(quant_config);
let quantized_model = quantizer.quantize(&model, &calibration_data)?;

// Benchmark original vs quantized
let input = Tensor::randn(&[1, 3, 224, 224])?;

let start = Instant::now();
let output1 = model.forward(&input)?;
let original_time = start.elapsed();

let start = Instant::now();
let output2 = quantized_model.forward(&input)?;
let quantized_time = start.elapsed();

println!("Original: {:?}", original_time);
println!("Quantized: {:?}", quantized_time);
println!("Speedup: {:.2}x", original_time.as_secs_f64() / quantized_time.as_secs_f64());
```
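At its core, post-training int8 quantization maps each float to an 8-bit integer using a scale and zero-point derived from the calibration range. A minimal sketch of asymmetric affine quantization in plain Rust (illustrative only; it is not what ToRSh's quantizer literally does internally):

```rust
/// Compute (scale, zero_point) for asymmetric u8 quantization of [min, max].
pub fn quant_params(min: f32, max: f32) -> (f32, u8) {
    let scale = (max - min) / 255.0;
    let zero_point = (-min / scale).round().clamp(0.0, 255.0) as u8;
    (scale, zero_point)
}

/// Map a float into the u8 grid defined by (scale, zero_point).
pub fn quantize(x: f32, scale: f32, zp: u8) -> u8 {
    (x / scale + zp as f32).round().clamp(0.0, 255.0) as u8
}

/// Recover the (approximate) float from its quantized value.
pub fn dequantize(q: u8, scale: f32, zp: u8) -> f32 {
    (q as i32 - zp as i32) as f32 * scale
}
```

The round trip `dequantize(quantize(x))` is only accurate to within one quantization step (`scale`), which is exactly why calibration data is needed to pick a tight `[min, max]` range.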
### Tutorial 7: Model Ensembling

```rust
use torsh_models::vision::*;
use torsh_models::ensemble::*;
use torsh::Tensor;

// Create multiple models
let model1 = resnet18()?;
let model2 = resnet34()?;
let model3 = efficientnet_b0()?;

// Create ensemble
let ensemble_config = EnsembleConfig { /* ... */ };
let mut ensemble = Ensemble::new(ensemble_config);
ensemble.add_model(model1)?;
ensemble.add_model(model2)?;
ensemble.add_model(model3)?;

// Inference with ensemble
let input = Tensor::randn(&[4, 3, 224, 224])?;
let ensemble_output = ensemble.forward(&input)?;
let predictions = ensemble_output.softmax(-1)?;
println!("Ensemble predictions: {:?}", predictions);
```
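A soft-voting ensemble simply averages the per-class probabilities of its members and picks the highest average. The core of that combination step can be sketched without any framework (names here are illustrative):

```rust
/// Average per-class probabilities from several models (soft voting).
/// Each inner vector is one model's probability vector for the same input.
pub fn soft_vote(outputs: &[Vec<f32>]) -> Vec<f32> {
    let n_classes = outputs[0].len();
    let mut avg = vec![0.0f32; n_classes];
    for probs in outputs {
        for (a, p) in avg.iter_mut().zip(probs) {
            *a += p / outputs.len() as f32;
        }
    }
    avg
}

/// Index of the highest averaged probability (the ensemble's prediction).
pub fn argmax(probs: &[f32]) -> usize {
    probs
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}
```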
## Migration Guide

### Migrating from PyTorch

ToRSh models are designed to be similar to PyTorch for easy migration. Here are common patterns:

#### Model Creation

PyTorch:

```python
import torchvision.models as models

# Load pre-trained model
model = models.resnet18(pretrained=True)
model.eval()
```

ToRSh:

```rust
use torsh_models::vision::resnet::*;

// Load pre-trained model
let mut model = resnet18()?;
model.eval();
```
#### Forward Pass

PyTorch:

```python
output = model(input)
predictions = torch.softmax(output, dim=-1)
```

ToRSh:

```rust
let output = model.forward(&input)?;
let predictions = output.softmax(-1)?;
```
#### Model Configuration

PyTorch:

```python
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=1024, num_attention_heads=16)
model = BertModel(config)
```

ToRSh:

```rust
use torsh_models::nlp::bert::*;

let config = BertConfig {
    hidden_size: 1024,
    num_attention_heads: 16,
    ..Default::default()
};
let model = BertModel::new(config)?;
```
#### Training Loop

PyTorch:

```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(input)
    loss = F.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
```

ToRSh:

```rust
use torsh::optim::SGD;
use torsh::nn::functional as F;

let mut optimizer = SGD::new(model.parameters(), 0.01)?;

for epoch in 0..num_epochs {
    optimizer.zero_grad();
    let output = model.forward(&input)?;
    let loss = F::cross_entropy(&output, &target)?;
    loss.backward()?;
    optimizer.step()?;
}
```
### Migrating from TensorFlow/Keras

#### Sequential Model

TensorFlow/Keras:

```python
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])
```

ToRSh:

```rust
use torsh::nn::*;

let model = Sequential::new()
    .add(Conv2d::new(3, 32, 3)?)
    .add(ReLU::new())
    .add(MaxPool2d::new(2)?)
    .add(Conv2d::new(32, 64, 3)?)
    .add(ReLU::new())
    .add(Flatten::new())
    .add(Linear::new(/* in_features */ 64 * 6 * 6, 10)?);
```
#### Model Compilation and Training

TensorFlow/Keras:

```python
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train, y_train, epochs=10)
```

ToRSh:

```rust
use torsh::optim::Adam;
use torsh::nn::functional as F;

let mut optimizer = Adam::new(model.parameters(), 1e-3)?;

for epoch in 0..10 {
    // For each batch: forward pass, F::cross_entropy loss,
    // backward pass, optimizer step.
}
```
### Common Migration Patterns

#### Error Handling

Python (with exceptions):

```python
try:
    model = load_model("resnet18")
    output = model(input)
except Exception as e:
    print(f"Error: {e}")
```

Rust (with Result types):

```rust
match load_model("resnet18") {
    Ok(model) => { /* use the model */ }
    Err(e) => eprintln!("Error: {}", e),
}

// Or using the ? operator for cleaner code:
let model = load_model("resnet18")?;
let output = model.forward(&input)?;
```
#### Device Management

PyTorch:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
input_tensor = input_tensor.to(device)
```

ToRSh:

```rust
use torsh::Device;

let device = if cuda_is_available() { Device::Cuda(0) } else { Device::Cpu };
let model = model.to_device(device)?;
let input_tensor = input_tensor.to_device(device)?;
```
#### Model Saving and Loading

PyTorch:

```python
# Save
torch.save(model.state_dict(), "model.pt")

# Load
model.load_state_dict(torch.load("model.pt"))
```

ToRSh:

```rust
// Save
model.save("model.torsh")?;

// Load
let model = load("model.torsh")?;
```
### Key Differences to Note

- Memory Safety: ToRSh provides compile-time memory safety guarantees
- Error Handling: Rust uses `Result<T, E>` instead of exceptions
- Ownership: Rust's ownership system requires explicit handling of data movement
- Immutability: Variables are immutable by default; use `mut` for mutable variables
- Type Safety: Strong static typing catches errors at compile time
- Performance: Zero-cost abstractions and no garbage collection
### Best Practices for Migration

- Start Small: Begin with simple models and gradually increase complexity
- Use Type Annotations: Leverage Rust's type system for better code clarity
- Handle Errors Properly: Use the `?` operator and proper error handling patterns
- Leverage Rust Tools: Use `cargo clippy` for linting and `cargo fmt` for formatting
- Test Thoroughly: Write unit tests to ensure model behavior matches expectations
- Use the Registry: Take advantage of the built-in model registry for pretrained models
## Available Models

### Vision
- ResNet (18, 34, 50, 101, 152)
- ResNeXt (50, 101)
- Wide ResNet
- EfficientNet (B0-B7)
- MobileNet (V2, V3)
- VGG (11, 13, 16, 19)
- DenseNet (121, 161, 169, 201)
- Vision Transformer (ViT)
- Swin Transformer
- ConvNeXt
### NLP
- BERT (Base, Large)
- RoBERTa
- GPT-2 (Small, Medium, Large, XL)
- T5 (Small, Base, Large)
- BART
- XLNet
- ELECTRA
### Audio
- Wav2Vec2
- Whisper
- HuBERT
- WavLM
### Detection & Segmentation
- Faster R-CNN
- Mask R-CNN
- YOLO (v5, v8)
- DETR
- DeepLabV3
- U-Net
## License
Licensed under the Apache License, Version 2.0. See LICENSE for details.