Embedding Trainer
A fast and flexible Rust library and CLI tool for training word embeddings from scratch using Skip-gram and CBOW algorithms with built-in validation, evaluation, and semantic search.
Features
Algorithms
- Skip-gram: Predicts context words given target words
- CBOW: Predicts target words given context words
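The sketch below makes the two objectives concrete by showing how training examples could be extracted from a sentence of token ids. It is an illustration, not the crate's code. skipgram_pairs and cbow_examples are hypothetical helper names. Refinements that real word2vec implementations use, such as random window shrinking and subsampling of frequent words, are omitted.

```rust
// Hypothetical helpers illustrating Skip-gram vs. CBOW example extraction.

/// Skip-gram: one (target, context) pair for every context word within `window`.
pub fn skipgram_pairs(tokens: &[usize], window: usize) -> Vec<(usize, usize)> {
    let mut pairs = Vec::new();
    for (i, &target) in tokens.iter().enumerate() {
        // Clamp the window to the sequence boundaries.
        let start = i.saturating_sub(window);
        let end = (i + window + 1).min(tokens.len());
        for (j, &context) in tokens.iter().enumerate().take(end).skip(start) {
            if j != i {
                pairs.push((target, context));
            }
        }
    }
    pairs
}

/// CBOW: one (all context words, center target) example per position.
pub fn cbow_examples(tokens: &[usize], window: usize) -> Vec<(Vec<usize>, usize)> {
    let mut examples = Vec::new();
    for (i, &target) in tokens.iter().enumerate() {
        let start = i.saturating_sub(window);
        let end = (i + window + 1).min(tokens.len());
        let context: Vec<usize> = (start..end).filter(|&j| j != i).map(|j| tokens[j]).collect();
        // A one-token sentence has no context, so such positions are skipped.
        if !context.is_empty() {
            examples.push((context, target));
        }
    }
    examples
}
```

With window = 1, the sentence [10, 20, 30] yields four Skip-gram pairs but only three CBOW examples. This is because CBOW groups all context words of a position into a single example.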
Training Features
- Configurable embedding dimensions
- Adjustable learning rates and epochs
- Customizable context windows
- Negative sampling support
- Batch processing capabilities
- Learning rate scheduling (constant, exponential, step, cosine)
- Early stopping with configurable patience
- L2 regularization and gradient clipping
- Train/validation split with metrics export
- Per-epoch training history / learning curves (JSON export)
- K-fold cross-validation with averaged metrics
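The schedule names above correspond to standard formulas. The following minimal sketch implements the four schedules plus a patience-based early-stopping check, assuming the common textbook definitions. LrSchedule and EarlyStopping are hypothetical names, and the crate's actual parameters may differ.

```rust
use std::f64::consts::PI;

/// Hypothetical learning-rate schedules (common textbook formulas).
#[derive(Debug, Clone, Copy)]
pub enum LrSchedule {
    /// lr = lr0
    Constant,
    /// lr = lr0 * gamma^epoch
    Exponential { gamma: f64 },
    /// lr = lr0 * gamma^(epoch / step_size); step_size must be > 0
    Step { step_size: usize, gamma: f64 },
    /// Cosine annealing from lr0 down to min_lr over total_epochs (> 0)
    Cosine { total_epochs: usize, min_lr: f64 },
}

impl LrSchedule {
    /// Learning rate for a 0-based epoch index, given the initial rate lr0.
    pub fn lr_at(&self, lr0: f64, epoch: usize) -> f64 {
        match *self {
            LrSchedule::Constant => lr0,
            LrSchedule::Exponential { gamma } => lr0 * gamma.powi(epoch as i32),
            LrSchedule::Step { step_size, gamma } => lr0 * gamma.powi((epoch / step_size) as i32),
            LrSchedule::Cosine { total_epochs, min_lr } => {
                let progress = epoch.min(total_epochs) as f64 / total_epochs as f64;
                min_lr + 0.5 * (lr0 - min_lr) * (1.0 + (PI * progress).cos())
            }
        }
    }
}

/// Hypothetical early-stopping tracker: stop after `patience` consecutive
/// epochs in which validation loss fails to improve by more than `min_delta`.
pub struct EarlyStopping {
    patience: usize,
    min_delta: f64,
    best: f64,
    bad_epochs: usize,
}

impl EarlyStopping {
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self { patience, min_delta, best: f64::INFINITY, bad_epochs: 0 }
    }

    /// Record one epoch's validation loss; returns true when training should stop.
    pub fn should_stop(&mut self, val_loss: f64) -> bool {
        if val_loss < self.best - self.min_delta {
            self.best = val_loss;
            self.bad_epochs = 0;
        } else {
            self.bad_epochs += 1;
        }
        self.bad_epochs >= self.patience
    }
}
```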
CLI Tools
- Training: Train embeddings from text data with optional validation split
- Similarity: Calculate semantic similarity between words
- Inspection: Analyze trained models and vocabulary
- Export: Save embeddings in multiple formats (text, JSON, binary, Word2Vec)
- Validate: Evaluate a saved model on held-out validation text
- Interactive: Query trained models interactively (similarity, analogy, search)
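Word-embedding tools conventionally answer analogy queries with vector arithmetic. For "a is to b as c is to ?", compute b - a + c and return the nearest remaining word by cosine similarity. The sketch below illustrates that technique, assuming embeddings are held in a HashMap<String, Vec<f32>>. The function name analogy is hypothetical and not the crate's API.

```rust
use std::collections::HashMap;

/// Cosine similarity; 0.0 when either vector has zero length.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na == 0.0 || nb == 0.0 { 0.0 } else { dot / (na * nb) }
}

/// Hypothetical analogy query: "a is to b as c is to ?" via the offset b - a + c.
/// Returns None if any query word is missing; the query words are excluded from the answer.
pub fn analogy<'a>(emb: &'a HashMap<String, Vec<f32>>, a: &str, b: &str, c: &str) -> Option<&'a str> {
    let (va, vb, vc) = (emb.get(a)?, emb.get(b)?, emb.get(c)?);
    let query: Vec<f32> = va.iter().zip(vb).zip(vc).map(|((x, y), z)| y - x + z).collect();
    emb.iter()
        .filter(|(w, _)| ![a, b, c].contains(&w.as_str()))
        .map(|(w, v)| (w.as_str(), cosine(&query, v)))
        .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(w, _)| w)
}
```

Excluding the query words matters: without the filter, b itself is often the nearest neighbour of b - a + c.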
Evaluation & Analysis
- Benchmarks: Evaluate against standard word similarity benchmarks (WordSim-353, SimLex-999) with Spearman correlation
- Clustering: K-means and hierarchical clustering of embeddings
- Cross-validation: K-fold cross-validation with per-fold metrics
- Learning curves: Per-epoch loss and learning rate tracking with JSON export
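Benchmarks such as WordSim-353 and SimLex-999 pair words with human similarity judgments. The usual evaluation computes a model similarity for each pair and reports Spearman's rank correlation against the human scores. The minimal sketch below covers the correlation step and assumes tied values receive averaged ranks, which is the standard convention. It is illustrative rather than the crate's implementation, and ranks and spearman are hypothetical names.

```rust
/// Assign 1-based ranks, averaging the ranks of tied values.
/// Panics on NaN input (partial_cmp returns None).
fn ranks(values: &[f64]) -> Vec<f64> {
    let mut idx: Vec<usize> = (0..values.len()).collect();
    idx.sort_by(|&a, &b| values[a].partial_cmp(&values[b]).unwrap());
    let mut r = vec![0.0; values.len()];
    let mut i = 0;
    while i < idx.len() {
        // Extend j over the run of values equal to values[idx[i]].
        let mut j = i;
        while j + 1 < idx.len() && values[idx[j + 1]] == values[idx[i]] {
            j += 1;
        }
        // Sorted positions i..=j share the mean of ranks i+1..=j+1.
        let avg = (i + j) as f64 / 2.0 + 1.0;
        for &k in &idx[i..=j] {
            r[k] = avg;
        }
        i = j + 1;
    }
    r
}

/// Spearman's rho: the Pearson correlation of the two rank vectors.
pub fn spearman(x: &[f64], y: &[f64]) -> f64 {
    assert_eq!(x.len(), y.len(), "score lists must have equal length");
    let (rx, ry) = (ranks(x), ranks(y));
    let n = rx.len() as f64;
    let mean_x = rx.iter().sum::<f64>() / n;
    let mean_y = ry.iter().sum::<f64>() / n;
    let (mut cov, mut var_x, mut var_y) = (0.0f64, 0.0f64, 0.0f64);
    for (a, b) in rx.iter().zip(&ry) {
        cov += (a - mean_x) * (b - mean_y);
        var_x += (a - mean_x).powi(2);
        var_y += (b - mean_y).powi(2);
    }
    // NaN if either list is constant (zero variance).
    cov / (var_x.sqrt() * var_y.sqrt())
}
```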
Data Support
- Text file processing with Unicode normalization
- Source code preprocessing (Rust, Python, JavaScript, etc.)
- BPE subword tokenization and FastText-style character n-grams
- WordPiece subword tokenization (BERT-style)
- Vocabulary management
- Model persistence
- Multiple export formats (JSON, binary, Word2Vec, ONNX, NumPy)
- Streaming support for large datasets
- Pluggable compute backend trait (CPU implemented, GPU ready)
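BERT-style WordPiece tokenization splits each word greedily. At each position it takes the longest vocabulary entry that matches and marks non-initial pieces with a ## prefix. The minimal sketch below assumes the vocabulary is a plain set of strings and that a word which cannot be fully matched maps to a single unknown token. wordpiece is a hypothetical function name, not the crate's tokenizer API.

```rust
use std::collections::HashSet;

/// Hypothetical greedy longest-match-first WordPiece tokenizer for one word.
pub fn wordpiece(word: &str, vocab: &HashSet<String>, unk: &str) -> Vec<String> {
    let chars: Vec<char> = word.chars().collect();
    let mut pieces = Vec::new();
    let mut start = 0;
    while start < chars.len() {
        let mut end = chars.len();
        let mut found = None;
        // Shrink the candidate from the right until it is in the vocabulary.
        while end > start {
            let mut candidate: String = chars[start..end].iter().collect();
            if start > 0 {
                candidate.insert_str(0, "##"); // continuation marker
            }
            if vocab.contains(&candidate) {
                found = Some(candidate);
                break;
            }
            end -= 1;
        }
        match found {
            Some(piece) => {
                pieces.push(piece);
                start = end;
            }
            // Any unmatched remainder makes the whole word unknown.
            None => return vec![unk.to_string()],
        }
    }
    pieces
}
```

With the vocabulary {un, ##aff, ##able}, the word "unaffable" becomes un, ##aff, ##able, while a word with no matching prefix becomes the unknown token.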
Advanced Models
- Transformer encoder: Multi-head self-attention with position encoding for contextualized embeddings
- Multi-modal fusion: Concatenation, weighted average, attention fusion, projection fusion, cross-modal similarity
- Real-time training: Incremental updates and streaming micro-batch training without full retrain
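Two core computations of a Transformer encoder are the position encoding and scaled dot-product attention. The README does not say which position encoding the crate uses, so the sketch below assumes the sinusoidal scheme from the original Transformer paper. It shows a single attention head. Multi-head attention runs several such heads on learned projections of the input and concatenates their outputs. All names here are hypothetical.

```rust
/// Sinusoidal position encoding: even dimensions use sin, odd use cos,
/// and each pair (2i, 2i+1) shares the frequency 1 / 10000^(2i / d_model).
pub fn positional_encoding(seq_len: usize, d_model: usize) -> Vec<Vec<f64>> {
    (0..seq_len)
        .map(|pos| {
            (0..d_model)
                .map(|i| {
                    let exponent = (2 * (i / 2)) as f64 / d_model as f64;
                    let angle = pos as f64 / 10000f64.powf(exponent);
                    if i % 2 == 0 { angle.sin() } else { angle.cos() }
                })
                .collect()
        })
        .collect()
}

/// Numerically stable softmax.
fn softmax(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
pub fn attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let scale = (k[0].len() as f64).sqrt();
    q.iter()
        .map(|qi| {
            let scores: Vec<f64> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() / scale)
                .collect();
            let weights = softmax(&scores);
            // Output row = attention-weighted sum of the value rows.
            let row: Vec<f64> = (0..v[0].len())
                .map(|c| weights.iter().zip(v).map(|(w, vj)| w * vj[c]).sum::<f64>())
                .collect();
            row
        })
        .collect()
}
```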
Quick Start
Installation
# Clone the repository
# Build the project
# Or install locally
GPU Acceleration (Optional)
Enable GPU compute via the gpu feature flag. This uses wgpu compute shaders and works on Vulkan, Metal, and DX12 backends without vendor-specific SDKs.
# Build with GPU support
# Install with GPU support
When the gpu feature is enabled, EmbeddingModel::new() automatically selects the best available backend (GPU if present, otherwise CPU). You can also explicitly create a GPU backend:
use ;
// Attempt GPU initialization; fails gracefully if no GPU is available
if let Ok = new
Note: GPU operations have CPU-GPU transfer overhead. For small models the CPU backend may still be faster. GPU acceleration shines with large batched matrix multiplications (matmul).
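The automatic CPU fallback described above follows a simple pattern: try to initialize the GPU, and on any error construct the CPU backend instead. The sketch below illustrates that pattern with a hypothetical ComputeBackend trait that has a single matmul operation. The crate's real trait, its constructors, and the wgpu setup are not shown here, so a caller-supplied closure stands in for GPU initialization.

```rust
/// Hypothetical compute-backend trait; the crate's real trait may differ.
pub trait ComputeBackend {
    fn name(&self) -> &'static str;
    /// Row-major matrix product: (m x k) times (k x n) gives (m x n).
    fn matmul(&self, a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32>;
}

/// Plain CPU implementation used as the fallback.
pub struct CpuBackend;

impl ComputeBackend for CpuBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn matmul(&self, a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        assert_eq!(a.len(), m * k);
        assert_eq!(b.len(), k * n);
        let mut out = vec![0.0; m * n];
        // i-p-j loop order walks both `b` and `out` row by row (cache friendly).
        for i in 0..m {
            for p in 0..k {
                let a_ip = a[i * k + p];
                for j in 0..n {
                    out[i * n + j] += a_ip * b[p * n + j];
                }
            }
        }
        out
    }
}

/// Try the GPU first; on any initialization error, fall back to the CPU backend.
/// `try_gpu` is a placeholder for the real (e.g. wgpu-based) initialization.
pub fn select_backend(
    try_gpu: impl FnOnce() -> Result<Box<dyn ComputeBackend>, String>,
) -> Box<dyn ComputeBackend> {
    match try_gpu() {
        Ok(gpu) => gpu,
        Err(_) => Box::new(CpuBackend),
    }
}
```

Returning a trait object keeps the rest of the training code independent of which backend was chosen.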
Basic Usage
1. Train Your First Embeddings
# Prepare your training data
# Train embeddings using Skip-gram
# Train with validation split
2. Calculate Similarity
# Calculate similarity between words
# Example output (the exact score depends on the corpus and training run):
# Similarity between 'fox' and 'dog': 0.8234
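The README does not define the score, but word-embedding similarity is conventionally the cosine of the angle between the two word vectors, so it ranges from -1 to 1. The minimal sketch below assumes that convention. cosine_similarity is a hypothetical name.

```rust
/// Cosine similarity dot(a, b) / (|a| * |b|); returns 0.0 if either vector is all zeros.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vectors must have the same dimension");
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
```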
3. Inspect Model
# View model information
# Shows vocabulary size, embedding dimension, training config
4. Export Embeddings
# Export to different formats
Library Usage
Basic Example
use *;
Advanced Usage
use *;
Configuration
Training Parameters
| Parameter | Description | Default Value | Range |
|---|---|---|---|
| --dim | Embedding dimension | 300 | 10-1000 |
| --learning-rate | Learning rate | 0.025 | 0.001-1.0 |
| --epochs | Number of training epochs | 10 | 1-1000 |
| --batch-size | Mini-batch size | 32 | 1-1000 |
| --window | Context window size | 5 | 1-20 |
| --negative-samples | Number of negative samples | 5 | 1-20 |
| --validation-ratio | Fraction of data used for validation | 0.0 | 0.0-0.5 |
| --validation-output | File to write validation metrics JSON to | - | - |
Algorithm Types
- skipgram: Skip-gram algorithm (default)
- cbow: Continuous Bag of Words
Export Formats
- text: Plain text format (default)
- json: JSON format with metadata
- bin: Binary format using bincode
- word2vec: Word2Vec/Gensim text format
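The word2vec text format is simple. It starts with a header line giving the vocabulary size and dimension, followed by one line per word containing the word and its vector components separated by spaces. The sketch below shows a writer for that format. write_word2vec_text is a hypothetical name, not the crate's export function.

```rust
use std::io::{self, Write};

/// Hypothetical writer for the word2vec text format:
/// header "<vocab_size> <dim>", then "<word> <x1> <x2> ..." per line.
/// Words must not contain spaces, since space is the field separator.
pub fn write_word2vec_text<W: Write>(out: &mut W, entries: &[(String, Vec<f32>)]) -> io::Result<()> {
    let dim = entries.first().map_or(0, |(_, v)| v.len());
    writeln!(out, "{} {}", entries.len(), dim)?;
    for (word, values) in entries {
        // All vectors are assumed to share the header's dimension.
        debug_assert_eq!(values.len(), dim);
        write!(out, "{word}")?;
        for x in values {
            write!(out, " {x}")?;
        }
        writeln!(out)?;
    }
    Ok(())
}
```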
CLI Reference
Training Command
Options:
- --input <FILE>: Input text file (required)
- --output <FILE>: Output model file (required)
- --embeddings <FILE>: Embeddings output file (required)
- --dim <SIZE>: Embedding dimension (default: 300)
- --learning-rate <RATE>: Learning rate (default: 0.025)
- --epochs <COUNT>: Number of epochs (default: 10)
- --batch-size <SIZE>: Batch size (default: 32)
- --window <SIZE>: Context window size (default: 5)
- --negative-samples <COUNT>: Negative samples (default: 5)
- --model-type <TYPE>: Algorithm type (skipgram|cbow)
- --validation-ratio <RATIO>: Fraction for validation (default: 0.0)
- --validation-output <FILE>: Path to write validation metrics JSON
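The following sketch shows how a --validation-ratio split could work. It assumes the ratio applies to sentences and that the held-out part is taken from the end of the corpus. The README does not document the crate's actual splitting strategy (shuffled, by line, or by token), and train_validation_split is a hypothetical name.

```rust
/// Hypothetical split: the last round(len * ratio) items become the validation set.
/// Shuffle beforehand if the corpus is ordered (e.g. sorted by topic or date).
pub fn train_validation_split<T: Clone>(items: &[T], ratio: f64) -> (Vec<T>, Vec<T>) {
    assert!((0.0..=0.5).contains(&ratio), "validation ratio must be in 0.0..=0.5");
    let n_val = (items.len() as f64 * ratio).round() as usize;
    let split_at = items.len() - n_val;
    (items[..split_at].to_vec(), items[split_at..].to_vec())
}
```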
Similarity Command
Options:
- --model <FILE>: Model file (required)
Info Command
Options:
- --model <FILE>: Model file (required)
Export Command
Options:
- --model <FILE>: Model file (required)
- --output <FILE>: Output file (required)
- --format <FORMAT>: Export format (text|json|bin|word2vec)
Validate Command
Options:
- --model <FILE>: Model file (required)
- --input <FILE>: Validation text file (required)
- --output <FILE>: Output metrics JSON file (optional)
Examples
Example 1: Basic Word Embeddings
# Create sample data
# Train embeddings
# Test similarity
Example 2: Document Embeddings
# Prepare document data
# Train with CBOW and validation
Example 3: Large Dataset Processing
# Process large file with multiple epochs
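For files too large to load at once, the usual approach is to stream them line by line. The sketch below uses only the standard library to build word counts, the first pass of vocabulary construction, from any BufRead source. count_words is a hypothetical helper. Its whitespace tokenization with lowercasing is a simplification of the preprocessing described earlier (Unicode normalization and so on).

```rust
use std::collections::HashMap;
use std::io::{self, BufRead};

/// Hypothetical helper: stream word counts from any buffered reader, one line at a
/// time, so memory grows with the vocabulary rather than with the file size.
pub fn count_words<R: BufRead>(reader: R) -> io::Result<HashMap<String, u64>> {
    let mut counts = HashMap::new();
    for line in reader.lines() {
        let line = line?; // propagate I/O and invalid-UTF-8 errors
        // Simplified tokenization: split on whitespace and lowercase.
        for token in line.split_whitespace() {
            *counts.entry(token.to_lowercase()).or_insert(0) += 1;
        }
    }
    Ok(counts)
}
```

To read a corpus file, pass BufReader::new(File::open(path)?) as the reader, where path is the file to process.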
Development
Building from Source
# Clone repository
# Build development version
# Run tests
# Run benchmarks
# Build documentation
Running Tests
# Run all tests
# Run specific test
# Run with verbose output
Development Features
- Unit Tests: Comprehensive test coverage
- Integration Tests: End-to-end testing
- Benchmarks: Performance testing
- Documentation: API documentation
Performance
Benchmarks
| Algorithm | Vocab Size | Embed Dim | Training Time | Memory Usage |
|---|---|---|---|---|
| Skip-gram | 10K words | 300 | 2.3s | 45MB |
| CBOW | 10K words | 300 | 1.8s | 42MB |
Optimization Tips
- Use appropriate batch sizes for your dataset
- Adjust learning rate based on dataset size
- Context window size affects training speed and quality
- Use negative sampling for large vocabularies
- Monitor memory usage with large datasets
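Negative sampling draws "noise" words from a smoothed unigram distribution. The original word2vec uses P(w) proportional to count(w)^0.75. This flattens the distribution, so rare words are sampled more often than their raw frequency alone would suggest. The sketch below builds that distribution and samples from it, assuming the caller supplies uniform random numbers. The function names are hypothetical, and the README does not state which exponent this crate uses.

```rust
/// Hypothetical helper: word2vec-style noise distribution P(w) ∝ count(w)^0.75.
pub fn noise_distribution(counts: &[u64]) -> Vec<f64> {
    let weights: Vec<f64> = counts.iter().map(|&c| (c as f64).powf(0.75)).collect();
    let total: f64 = weights.iter().sum();
    weights.iter().map(|w| w / total).collect()
}

/// Map a uniform sample `u` in [0, 1) to a word index by walking the cumulative
/// distribution. (Production code typically precomputes a large lookup table instead.)
pub fn sample_index(probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0f64;
    for (i, p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    // Guard against floating-point rounding leaving the total slightly below 1.
    probs.len() - 1
}
```

For counts 16 and 1, the raw frequencies are 16/17 and 1/17, but the smoothed probabilities are 8/9 and 1/9, so the rare word is drawn almost twice as often.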
Contributing
We welcome contributions! Please see our Contributing Guide for details.
Development Workflow
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests for new functionality
- Run the test suite
- Submit a pull request
Code Style
- Follow Rust formatting standards
- Use cargo fmt for code formatting
- Add comprehensive documentation
- Include tests for new features
Roadmap
Version 1.0 (Current)
- ✅ Skip-gram and CBOW algorithms
- ✅ CLI interface with train, validate, similarity, info, export
- ✅ Model persistence (JSON, binary, Word2Vec, ONNX, NumPy)
- ✅ Similarity calculations and semantic search
- ✅ Validation split and evaluation metrics (accuracy, precision, recall, F1)
- ✅ Learning rate scheduling (constant, exponential, step, cosine)
- ✅ Early stopping and L2 regularization
Version 1.1 (Current, Features Complete)
- ✅ Backend abstraction trait for GPU acceleration (CPU implemented)
- ✅ WordPiece subword tokenization
- ✅ K-fold cross-validation support
- ✅ Per-epoch training history / learning curve JSON export
- ✅ Standard word similarity benchmark evaluation (Spearman correlation)
- ✅ K-means clustering
- CUDA/OpenCL backend implementation (planned)
Version 2.0 (Current, Features Complete)
- ✅ Transformer encoder with multi-head self-attention and position encoding
- ✅ Enhanced multi-modal fusion (attention fusion, projection fusion, cross-modal similarity)
- ✅ Real-time incremental training (IncrementalTrainer with batch and stream modes)
Comparison with Alternatives
| Feature | embedding (this crate) | Gensim (Python) | rust-bert | fastText |
|---|---|---|---|---|
| Language | Rust | Python | Rust | C++ / Python |
| Algorithms | Skip-gram, CBOW, Transformer | Word2Vec, FastText, GloVe, LSI, LDA | BERT, RoBERTa, DistilBERT | Skip-gram, CBOW + subwords |
| WordPiece tokenization | ✅ | ❌ | ✅ | ❌ |
| BPE tokenization | ✅ | ❌ | ✅ | ❌ |
| GPU acceleration | ✅ (wgpu compute shaders, optional) | ❌ | ✅ (via ONNX / tch) | ❌ |
| Cross-validation | ✅ (k-fold) | ❌ | ❌ | ❌ |
| Learning curves | โ (per-epoch JSON export) | โ | โ | โ |
| Benchmark evaluation | โ (Spearman correlation) | โ (similarity tasks) | โ | โ |
| K-means clustering | ✅ | ❌ | ❌ | ❌ |
| Incremental training | โ (stream / batch updates) | โ (requires retrain) | โ | โ |
| Multi-modal fusion | ✅ (4 fusion strategies) | ❌ | ❌ | ❌ |
| CLI tool | โ (train, validate, search, export) | โ | โ | โ |
| Export formats | JSON, binary, Word2Vec, ONNX, NumPy | Word2Vec, Gensim native | ONNX | .vec, .bin |
| Memory mapping | โ (binary format) | โ | โ | โ |
| Pre-trained models | ✅ (Word2Vec text/binary, GloVe, fastText, mmap .bin) | ✅ (many built-in) | ✅ (Hugging Face) | ✅ |
| Sentence embeddings | โ (mean pooling) | โ (Doc2Vec) | โ (BERT pooling) | โ |
| Speed | ⚡ Fast (Rust native) | Python overhead | ⚡ Fast (Rust native) | ⚡ Fast (C++) |
| Zero dependencies for inference | ✅ (after training) | ❌ (Gensim + NumPy + SciPy) | ❌ (ONNX / torch) | ✅ (.vec format) |
Legend: ✅ = Supported | ❌ = Not supported | 🔶 = Partial / planned
Troubleshooting
Common Issues
Memory Error with Large Datasets
- Reduce batch size
- Use streaming processing
- Increase system memory
Poor Similarity Results
- Increase training epochs
- Adjust learning rate
- Try the other algorithm (skipgram or cbow)
Missing Words in Vocabulary
- Check text preprocessing
- Verify tokenization
- Ensure the words actually appear in the training text
Performance Issues
- Slow Training: Reduce batch size or use negative sampling
- High Memory Usage: Use smaller embedding dimensions
- Poor Quality: Increase epochs or adjust parameters
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.