Embedding Trainer
A fast and flexible Rust library and CLI tool for training word embeddings from scratch using Skip-gram and CBOW algorithms with built-in validation and evaluation.
Features
Algorithms
- Skip-gram: Predicts context words given target words
- CBOW: Predicts target words given context words
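The difference between the two algorithms shows up most clearly in how training examples are built from a token stream. The sketch below is illustrative only: the function names `skipgram_pairs` and `cbow_examples` are hypothetical and are not taken from this library's API. Tokens are assumed to be already mapped to integer vocabulary ids.

```rust
/// Skip-gram: emit one (target, context) pair for every neighbor inside the window.
pub fn skipgram_pairs(tokens: &[usize], window: usize) -> Vec<(usize, usize)> {
    let mut pairs = Vec::new();
    for (i, &target) in tokens.iter().enumerate() {
        let start = i.saturating_sub(window);
        let end = (i + window + 1).min(tokens.len());
        for j in start..end {
            if j != i {
                pairs.push((target, tokens[j]));
            }
        }
    }
    pairs
}

/// CBOW: emit one (context words, target) example per position.
pub fn cbow_examples(tokens: &[usize], window: usize) -> Vec<(Vec<usize>, usize)> {
    let mut examples = Vec::new();
    for (i, &target) in tokens.iter().enumerate() {
        let start = i.saturating_sub(window);
        let end = (i + window + 1).min(tokens.len());
        // Collect every token in the window except the target itself.
        let context: Vec<usize> = (start..end)
            .filter(|&j| j != i)
            .map(|j| tokens[j])
            .collect();
        if !context.is_empty() {
            examples.push((context, target));
        }
    }
    examples
}
```

For the same corpus and window, Skip-gram produces more (and smaller) examples than CBOW, which is consistent with the training-time difference reported in the benchmarks.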
Training Features
- Configurable embedding dimensions
- Adjustable learning rates and epochs
- Customizable context windows
- Negative sampling support
- Batch processing capabilities
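To make the learning rate and negative sampling options concrete, here is a minimal sketch of one stochastic gradient step for Skip-gram with negative sampling. It is a generic textbook formulation, not this library's implementation; `sgns_step` and its argument layout are assumptions made for illustration.

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// One SGD step for a single target word.
/// `input` is the target word's vector. Each entry of `samples` is an
/// output vector plus a label: 1.0 for the true context word, 0.0 for a
/// negative sample. Returns the binary cross-entropy loss before the update.
pub fn sgns_step(input: &mut [f32], samples: &mut [(Vec<f32>, f32)], lr: f32) -> f32 {
    let mut input_grad = vec![0.0f32; input.len()];
    let mut loss = 0.0f32;
    for (output, label) in samples.iter_mut() {
        let p = sigmoid(dot(input, output.as_slice()));
        loss -= if *label > 0.5 {
            p.max(1e-7).ln()
        } else {
            (1.0 - p).max(1e-7).ln()
        };
        // Gradient of the loss with respect to the dot-product score.
        let g = p - *label;
        for k in 0..input.len() {
            input_grad[k] += g * output[k];
            output[k] -= lr * g * input[k];
        }
    }
    // Apply the accumulated gradient to the input vector last, so every
    // sample above saw the same input values.
    for k in 0..input.len() {
        input[k] -= lr * input_grad[k];
    }
    loss
}
```

Each step touches only one positive and a handful of negative output vectors, which is why negative sampling scales to large vocabularies better than a full softmax.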
CLI Tools
- Training: Train embeddings from text data with optional validation split
- Similarity: Calculate semantic similarity between words
- Inspection: Analyze trained models and vocabulary
- Export: Save embeddings in multiple formats (text, JSON, binary, Word2Vec)
- Validate: Evaluate a saved model on held-out validation text
Data Support
- Text file processing
- Vocabulary management
- Model persistence
- Multiple export formats
- Streaming support for large datasets
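As an illustration of what streaming means for vocabulary building, the sketch below counts tokens one line at a time from any `BufRead` source, so the corpus never has to fit in memory. The helper names `count_tokens` and `build_vocab`, the whitespace tokenizer, and the lowercasing step are assumptions for this example and may differ from what the library does.

```rust
use std::collections::HashMap;
use std::io::{self, BufRead};

/// Counts lowercased, whitespace-separated tokens line by line.
pub fn count_tokens<R: BufRead>(reader: R) -> io::Result<HashMap<String, u64>> {
    let mut counts = HashMap::new();
    for line in reader.lines() {
        let line = line?;
        for token in line.split_whitespace() {
            *counts.entry(token.to_lowercase()).or_insert(0u64) += 1;
        }
    }
    Ok(counts)
}

/// Keeps words seen at least `min_count` times, most frequent first,
/// with ties broken alphabetically so the ordering is deterministic.
pub fn build_vocab(counts: &HashMap<String, u64>, min_count: u64) -> Vec<String> {
    let mut words: Vec<(&String, &u64)> = counts
        .iter()
        .filter(|entry| *entry.1 >= min_count)
        .collect();
    words.sort_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));
    words.into_iter().map(|(word, _)| word.clone()).collect()
}
```

Wrapping a `std::fs::File` in `std::io::BufReader` and passing it to `count_tokens` gives the same result for a file on disk.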
Quick Start
Installation
# Clone the repository
# Build the project
# Or install locally
Basic Usage
1. Train Your First Embeddings
# Prepare your training data
# Train embeddings using Skip-gram
# Train with validation split
2. Calculate Similarity
# Calculate similarity between words
# Expected output:
# Similarity between 'fox' and 'dog': 0.8234
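A score such as 0.8234 is what a cosine similarity between two word vectors would look like. This README does not state which measure the CLI uses, so the sketch below assumes cosine similarity, the usual choice for word embeddings; `cosine_similarity` is a hypothetical helper, not a documented function of this library.

```rust
/// Cosine similarity in [-1, 1].
/// Returns `None` when the lengths differ or either vector has zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() {
        return None;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    Some(dot / (norm_a * norm_b))
}
```

Returning `Option` keeps the zero-vector and mismatched-dimension cases explicit instead of producing `NaN`.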
3. Inspect Model
# View model information
# Shows vocabulary size, embedding dimension, training config
4. Export Embeddings
# Export to different formats
Library Usage
Basic Example
use *;
Advanced Usage
use *;
use fs;
Configuration
Training Parameters
| Parameter | Description | Default Value | Range |
|---|---|---|---|
| --dim | Embedding dimension | 300 | 10-1000 |
| --learning-rate | Learning rate | 0.025 | 0.001-1.0 |
| --epochs | Number of training epochs | 10 | 1-1000 |
| --batch-size | Mini-batch size | 32 | 1-1000 |
| --window | Context window size | 5 | 1-20 |
| --negative-samples | Number of negative samples | 5 | 1-20 |
| --validation-ratio | Fraction of data for validation | 0.0 | 0.0-0.5 |
| --validation-output | File to write validation metrics JSON | - | - |
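The effect of --validation-ratio can be pictured as holding out a fraction of the training items. The sketch below assumes the held-out part is taken from the end of the data and that the ratio is restricted to the 0.0-0.5 range listed in the table; whether the tool splits by sentence, by token, or after shuffling is not specified here, and `split_validation` is a hypothetical name.

```rust
/// Splits `items` into (training, validation) slices, holding out the
/// last `ratio` fraction. Returns `None` for ratios outside 0.0..=0.5.
pub fn split_validation<T>(items: &[T], ratio: f64) -> Option<(&[T], &[T])> {
    if !(0.0..=0.5).contains(&ratio) {
        return None;
    }
    let held_out = (items.len() as f64 * ratio).round() as usize;
    Some(items.split_at(items.len() - held_out))
}
```

With the default ratio of 0.0 the validation slice is empty and all data is used for training.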
Algorithm Types
- skipgram: Skip-gram algorithm (default)
- cbow: Continuous Bag of Words
Export Formats
- text: Plain text format (default)
- json: JSON format with metadata
- bin: Binary format using bincode
- word2vec: Word2Vec/Gensim text format
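For reference, the Word2Vec text format starts with a header line containing the vocabulary size and the dimension, followed by one line per word with its space-separated vector components. The writer below is a minimal sketch of that layout using only the standard library; `write_word2vec_text` is a hypothetical function, and the exact number formatting used by this tool's exporter is not documented here.

```rust
use std::io::{self, Write};

/// Writes embeddings in Word2Vec text layout:
/// "<vocab_size> <dim>" on the first line, then "<word> <v1> <v2> ..." per word.
pub fn write_word2vec_text<W: Write>(
    mut out: W,
    words: &[String],
    vectors: &[Vec<f32>],
) -> io::Result<()> {
    if words.len() != vectors.len() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "words and vectors differ in length",
        ));
    }
    let dim = vectors.first().map_or(0, |v| v.len());
    writeln!(out, "{} {}", words.len(), dim)?;
    for (word, vector) in words.iter().zip(vectors) {
        if vector.len() != dim {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "vectors have inconsistent dimensions",
            ));
        }
        write!(out, "{}", word)?;
        for value in vector {
            write!(out, " {}", value)?;
        }
        writeln!(out)?;
    }
    Ok(())
}
```

Because the function is generic over `Write`, it can target a file, a buffer, or standard output.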
CLI Reference
Training Command
Options:
- --input <FILE> - Input text file (required)
- --output <FILE> - Output model file (required)
- --embeddings <FILE> - Embeddings output file (required)
- --dim <SIZE> - Embedding dimension (default: 300)
- --learning-rate <RATE> - Learning rate (default: 0.025)
- --epochs <COUNT> - Number of epochs (default: 10)
- --batch-size <SIZE> - Batch size (default: 32)
- --window <SIZE> - Context window size (default: 5)
- --negative-samples <COUNT> - Negative samples (default: 5)
- --model-type <TYPE> - Algorithm type (skipgram|cbow)
- --validation-ratio <RATIO> - Fraction for validation (default: 0.0)
- --validation-output <FILE> - Path to write validation metrics JSON
Similarity Command
Options:
- --model <FILE> - Model file (required)
Info Command
Options:
- --model <FILE> - Model file (required)
Export Command
Options:
- --model <FILE> - Model file (required)
- --output <FILE> - Output file (required)
- --format <FORMAT> - Export format (text|json|bin|word2vec)
Validate Command
Options:
- --model <FILE> - Model file (required)
- --input <FILE> - Validation text file (required)
- --output <FILE> - Output metrics JSON file (optional)
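The roadmap lists accuracy, precision, recall, and F1 as the evaluation metrics. How the tool derives its prediction counts is not described in this README, so the sketch below only shows the standard definitions computed from true/false positive and negative counts. The `Metrics` struct and `metrics_from_counts` function are hypothetical and do not describe the layout of the metrics JSON file.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Metrics {
    pub accuracy: f64,
    pub precision: f64,
    pub recall: f64,
    pub f1: f64,
}

/// Standard binary-classification metrics from confusion-matrix counts.
/// Any ratio with a zero denominator is reported as 0.0.
pub fn metrics_from_counts(tp: u64, fp: u64, tn: u64, fn_: u64) -> Metrics {
    let ratio = |num: u64, den: u64| if den == 0 { 0.0 } else { num as f64 / den as f64 };
    let precision = ratio(tp, tp + fp);
    let recall = ratio(tp, tp + fn_);
    let f1 = if precision + recall == 0.0 {
        0.0
    } else {
        2.0 * precision * recall / (precision + recall)
    };
    Metrics {
        accuracy: ratio(tp + tn, tp + fp + tn + fn_),
        precision,
        recall,
        f1,
    }
}
```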
Examples
Example 1: Basic Word Embeddings
# Create sample data
# Train embeddings
# Test similarity
Example 2: Document Embeddings
# Prepare document data
# Train with CBOW and validation
Example 3: Large Dataset Processing
# Process large file with multiple epochs
Development
Building from Source
# Clone repository
# Build development version
# Run tests
# Run benchmarks
# Build documentation
Running Tests
# Run all tests
# Run specific test
# Run with verbose output
Development Features
- Unit Tests: Comprehensive test coverage
- Integration Tests: End-to-end testing
- Benchmarks: Performance testing
- Documentation: API documentation
Performance
Benchmarks
| Algorithm | Vocab Size | Embed Dim | Training Time | Memory Usage |
|---|---|---|---|---|
| Skip-gram | 10K words | 300 | 2.3s | 45MB |
| CBOW | 10K words | 300 | 1.8s | 42MB |
Optimization Tips
- Use appropriate batch sizes for your dataset
- Adjust learning rate based on dataset size
- Context window size affects training speed and quality
- Use negative sampling for large vocabularies
- Monitor memory usage with large datasets
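The tip about negative sampling for large vocabularies relies on drawing negatives cheaply. A common approach from the original Word2Vec work is to sample words in proportion to their count raised to the power 0.75; whether this library uses the same exponent is an assumption. The sketch below builds a cumulative table once and then maps a uniform random number to a word index by binary search. `build_cdf` and `sample_index` are hypothetical names, and the caller supplies the random number so the example needs no external crate.

```rust
/// Cumulative distribution over the vocabulary with weights count^0.75.
/// Assumes at least one count is non-zero.
pub fn build_cdf(counts: &[u64]) -> Vec<f64> {
    let weights: Vec<f64> = counts.iter().map(|&c| (c as f64).powf(0.75)).collect();
    let total: f64 = weights.iter().sum();
    let mut acc = 0.0;
    weights
        .iter()
        .map(|w| {
            acc += w / total;
            acc
        })
        .collect()
}

/// Maps a uniform `u` in [0, 1) to a word index: the first entry whose
/// cumulative probability exceeds `u`, clamped to the last index.
pub fn sample_index(cdf: &[f64], u: f64) -> usize {
    cdf.partition_point(|&p| p <= u)
        .min(cdf.len().saturating_sub(1))
}
```

Each draw costs O(log V) for a vocabulary of size V, and the 0.75 exponent gives rare words a somewhat higher chance of being picked than their raw frequency would.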
Contributing
We welcome contributions! Please see our Contributing Guide for details.
Development Workflow
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests for new functionality
- Run the test suite
- Submit a pull request
Code Style
- Follow Rust formatting standards
- Use cargo fmt for code formatting
- Add comprehensive documentation
- Include tests for new features
Roadmap
Version 1.0 (Current)
- Skip-gram and CBOW algorithms
- CLI interface with train, validate, similarity, info, export
- Model persistence (JSON, binary, Word2Vec, ONNX, NumPy)
- Similarity calculations and semantic search
- Validation split and evaluation metrics (accuracy, precision, recall, F1)
- Learning rate scheduling (constant, exponential, step, cosine)
- Early stopping and L2 regularization
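The four learning rate schedules named above can be sketched with their usual textbook formulas. The `Schedule` enum, its field names, and the exact formulas are assumptions for illustration; the library's own configuration types may differ.

```rust
#[derive(Debug, Clone, Copy)]
pub enum Schedule {
    /// lr = base
    Constant,
    /// lr = base * decay^epoch
    Exponential { decay: f64 },
    /// lr = base * factor^(epoch / every), using integer division
    Step { every: usize, factor: f64 },
    /// Half-cosine decay from base down to min_lr over total_epochs
    Cosine { total_epochs: usize, min_lr: f64 },
}

/// Learning rate for a zero-based `epoch` under the given schedule.
pub fn learning_rate(schedule: Schedule, base: f64, epoch: usize) -> f64 {
    match schedule {
        Schedule::Constant => base,
        Schedule::Exponential { decay } => base * decay.powi(epoch as i32),
        Schedule::Step { every, factor } => {
            base * factor.powi((epoch / every.max(1)) as i32)
        }
        Schedule::Cosine { total_epochs, min_lr } => {
            // Progress through training, clamped so epochs past the end stay at min_lr.
            let t = (epoch as f64 / total_epochs.max(1) as f64).min(1.0);
            min_lr + 0.5 * (base - min_lr) * (1.0 + (std::f64::consts::PI * t).cos())
        }
    }
}
```

With the default base rate of 0.025, the cosine schedule over 10 epochs yields 0.025 at epoch 0, 0.0125 at epoch 5, and approaches `min_lr` at epoch 10.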
Version 1.1 (Planned)
- GPU acceleration
- Advanced tokenization improvements
- Cross-validation support
- Learning curve visualization
Version 2.0 (Future)
- Transformer-based models
- Multi-modal embeddings
- Real-time training
- Standard word similarity benchmarks integration
Troubleshooting
Common Issues
- Memory Error with Large Datasets
  - Reduce batch size
  - Use streaming processing
  - Increase system memory
- Poor Similarity Results
  - Increase training epochs
  - Adjust learning rate
  - Try different algorithms
- Missing Words in Vocabulary
  - Check text preprocessing
  - Verify tokenization
  - Ensure words appear in text
Performance Issues
- Slow Training: Reduce batch size or use negative sampling
- High Memory Usage: Use smaller embedding dimensions
- Poor Quality: Increase epochs or adjust parameters
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
Acknowledgments
- Inspired by Word2Vec, GloVe, and BERT
- Built with ndarray for numerical computing
- CLI powered by clap
- Serialization using serde
Support
- Email: your.email@example.com
- Discussions: GitHub Discussions
- Issues: GitHub Issues
- Documentation: docs.rs/embedding-trainer
Made with ❤️ by the Embedding Trainer Team
For the latest updates, check our GitHub repository