# axonml-llm

## Overview
axonml-llm provides implementations of popular transformer-based large language model architectures for the AxonML framework. This crate includes complete implementations of BERT and GPT-2 models, along with modular building blocks for constructing custom LLM architectures.
Built on top of axonml-tensor and axonml-autograd, this crate enables training and inference of transformer models entirely in pure Rust.
## Features

- **BERT Implementation** - Full Bidirectional Encoder Representations from Transformers with support for sequence classification, masked language modeling, and custom heads.
- **GPT-2 Implementation** - Complete Generative Pre-trained Transformer 2 with all model sizes (Small, Medium, Large, XL) and a language modeling head.
- **Multi-Head Attention** - Efficient multi-head self-attention and causal self-attention implementations with configurable heads and dimensions (see the sketch after this list).
- **Transformer Building Blocks** - Modular encoder and decoder blocks with layer normalization, feed-forward networks, and residual connections.
- **Embedding Layers** - Token embeddings, learned positional embeddings, sinusoidal positional encodings, and BERT/GPT-2 combined embeddings.
- **Text Generation** - Comprehensive generation utilities including greedy decoding, temperature sampling, top-k/top-p filtering, and beam search.
- **Configurable Architectures** - Pre-defined configurations for BERT-base, BERT-large, GPT-2 Small/Medium/Large/XL, and tiny variants for testing.
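The attention building block is not exercised by the usage examples below, so here is a minimal sketch of constructing a standalone attention layer. The `MultiHeadAttention` type and its `(embed_dim, num_heads)` constructor are assumptions based on the `attention` module description, not a confirmed API; consult the module docs for the exact signatures.

```rust
use axonml_llm::attention::MultiHeadAttention; // assumed type name

// Hypothetical: 8 attention heads over a 512-dimensional model.
// The (embed_dim, num_heads) constructor shape is an assumption.
let attn = MultiHeadAttention::new(512, 8);
```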
## Modules

| Module | Description |
|---|---|
| `attention` | Multi-head self-attention and causal self-attention mechanisms |
| `bert` | BERT model with classification and masked LM variants |
| `config` | Configuration structs for BERT, GPT-2, and base transformers |
| `embedding` | Token, positional, and combined embedding layers |
| `error` | Error types and result definitions for LLM operations |
| `generation` | Text generation utilities, sampling strategies, and beam search |
| `gpt2` | GPT-2 model with language modeling head |
| `transformer` | Encoder/decoder blocks, layer norm, and feed-forward networks |
## Usage

Add the crate to your `Cargo.toml`:

```toml
[dependencies]
axonml-llm = "0.1.0"
```
### GPT-2 Text Generation

```rust
use axonml_llm::gpt2::{GPT2Config, GPT2Model}; // GPT2Model name is assumed
use axonml_tensor::Tensor;

// Create a GPT-2 model
let config = GPT2Config::small();
let model = GPT2Model::new(&config);

// Input token IDs (values, shape, and from_vec signature are illustrative)
let input_ids = Tensor::from_vec(vec![464, 3290, 318], &[1, 3]).unwrap();

// Generate text with sampling (max-new-tokens argument is illustrative)
let output = model.generate(&input_ids, 50);
println!("{:?}", output);

// Or use greedy decoding
let greedy_output = model.generate_greedy(&input_ids, 50);
```
### BERT for Sequence Classification

```rust
use axonml_llm::bert::{BertConfig, BertForSequenceClassification}; // classifier type name is assumed
use axonml_tensor::Tensor;

// Create BERT for binary classification (number of labels is illustrative)
let config = BertConfig::base();
let model = BertForSequenceClassification::new(&config, 2);

// Input token IDs (values and shape are illustrative)
let input_ids = Tensor::from_vec(vec![101, 7592, 2088, 102], &[1, 4]).unwrap();

// Get classification logits
let logits = model.forward_classification(&input_ids);
println!("{:?}", logits);
```
### Custom Transformer Encoder

```rust
use axonml_llm::transformer::TransformerEncoder; // encoder type name is assumed
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Build a custom encoder stack (argument order is illustrative:
// layers, hidden size, heads, feed-forward size)
let encoder = TransformerEncoder::new(6, 512, 8, 2048);

// Forward pass over a [batch, seq_len, hidden] input
// (Variable::new and Tensor::zeros signatures are assumed)
let input = Variable::new(Tensor::zeros(&[1, 10, 512]));
let output = encoder.forward(&input);
```
### Generation Configuration

```rust
use axonml_llm::generation::{GenerationConfig, Generator}; // Generator type name is assumed

// Configure generation parameters (argument values are illustrative)
let config = GenerationConfig::nucleus_sampling(0.9, 0.8)
    .with_max_tokens(100)
    .with_repetition_penalty(1.2)
    .with_eos_token(50256);

let generator = Generator::new(config);

// Use with model logits (logits come from a model forward pass)
let next_token = generator.get_next_token(&logits);
```
### BERT for Masked Language Modeling

```rust
use axonml_llm::bert::{BertConfig, BertForMaskedLM}; // MLM type name is assumed
use axonml_tensor::Tensor;

// Create BERT for MLM
let config = BertConfig::base();
let model = BertForMaskedLM::new(&config);

// Input with [MASK] token (values and shape are illustrative)
let input_ids = Tensor::from_vec(vec![101, 7592, 103, 102], &[1, 4]).unwrap();

// Get MLM logits
let logits = model.forward_mlm(&input_ids);
// Shape: [batch, seq_len, vocab_size]
```
## Model Configurations

### BERT Configurations

| Config | Hidden Size | Layers | Heads | Parameters |
|---|---|---|---|---|
| `BertConfig::tiny()` | 128 | 2 | 2 | ~4M |
| `BertConfig::base()` | 768 | 12 | 12 | ~110M |
| `BertConfig::large()` | 1024 | 24 | 16 | ~340M |

### GPT-2 Configurations

| Config | Embedding Dim | Layers | Heads | Parameters |
|---|---|---|---|---|
| `GPT2Config::tiny()` | 128 | 2 | 2 | ~4M |
| `GPT2Config::small()` | 768 | 12 | 12 | ~117M |
| `GPT2Config::medium()` | 1024 | 24 | 16 | ~345M |
| `GPT2Config::large()` | 1280 | 36 | 20 | ~774M |
| `GPT2Config::xl()` | 1600 | 48 | 25 | ~1.5B |
## Generation Strategies

| Strategy | Method | Description |
|---|---|---|
| Greedy | `GenerationConfig::greedy()` | Always selects the highest-probability token |
| Sampling | `GenerationConfig::sampling(temp)` | Temperature-controlled sampling |
| Top-K | `GenerationConfig::top_k_sampling(k, temp)` | Sample from the top-k tokens |
| Nucleus | `GenerationConfig::nucleus_sampling(p, temp)` | Sample from the top-p probability mass |
| Beam Search | `GenerationConfig::beam_search(beams)` | Beam search decoding |
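Each row maps to a single constructor on `GenerationConfig`. The sketch below exercises all five; the argument values (temperature, k, p, beam count) are illustrative, not crate defaults:

```rust
use axonml_llm::generation::GenerationConfig;

// One constructor per strategy; all argument values are illustrative
let greedy = GenerationConfig::greedy();
let sampled = GenerationConfig::sampling(0.8);              // temperature
let top_k = GenerationConfig::top_k_sampling(40, 0.8);      // k, temperature
let nucleus = GenerationConfig::nucleus_sampling(0.9, 0.8); // p, temperature
let beam = GenerationConfig::beam_search(4);                // number of beams
```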
## Tests
Run the test suite:
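```bash
cargo test
```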
Run with verbose output:
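```bash
# --nocapture prints test stdout instead of capturing it
cargo test -- --nocapture
```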
## License
Licensed under either of:
- MIT License
- Apache License, Version 2.0
at your option.