# axonml-llm

## Overview
axonml-llm provides nine large-language-model architectures for the AxonML framework, all implemented in pure Rust on top of axonml-tensor and axonml-autograd. Shared infrastructure includes multi-head / causal self-attention with a KV cache, a FlashAttention kernel, RoPE, RMSNorm, a HuggingFace weight loader (safetensors), a state-dict name-mapping helper, a pretrained-weights hub with on-disk caching, a configurable text-generation sampler, and an HF-style tokenizer.
## Supported architectures

| Architecture | Module | Notes |
|--------------|--------|-------|
| GPT-2 | gpt2 | Decoder-only transformer with learned positional embeddings and GPT2LMHead. |
| BERT | bert | Bidirectional encoder with BertForSequenceClassification and BertForMaskedLM. |
| LLaMA | llama | LLaMA-2/3-style: split-halves RoPE, grouped-query attention, SwiGLU MLP, RMSNorm. LLaMAForCausalLM. |
| Mistral | mistral | LLaMA + sliding-window attention. MistralForCausalLM. |
| Phi | phi | Phi-1/2/3-style: partial RoPE, GELU MLP, optional parallel-attn block. PhiForCausalLM. |
| SSM | ssm | Mamba/S6-style selective state-space model: depthwise Conv1d + selective scan + RMSNorm. SSMBlock, SSMForCausalLM. |
| Hydra | hydra | Hybrid: alternates SSM blocks and windowed (local) attention. HydraModel. |
| Chimera | chimera | Sparse MoE (top-k routing) + differential attention, with load-balancing auxiliary loss. ChimeraModel. |
| Trident | trident | 1.58-bit ternary weights (TernaryLinear) + RoPE + GQA + ReLU²-gated FFN + SubLN (BitNet b1.58-2B-4T recipe). TridentModel. |
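To make the Trident row concrete: 1.58-bit ternary quantization in the BitNet b1.58 recipe maps each weight to {-1, 0, +1} using the mean absolute value of the weight group as the scale. The sketch below is a std-only illustration of that absmean rule, not the crate's TernaryLinear implementation:

```rust
/// Absmean ternary quantization (BitNet b1.58-style, illustrative):
/// scale = mean(|w|); each weight maps to round(w / scale), clamped to {-1, 0, +1}.
fn quantize_ternary(weights: &[f32]) -> (Vec<i8>, f32) {
    let scale = weights.iter().map(|w| w.abs()).sum::<f32>() / weights.len() as f32;
    let q = weights
        .iter()
        .map(|w| (w / scale).round().clamp(-1.0, 1.0) as i8)
        .collect();
    (q, scale)
}

fn main() {
    let w = [0.9_f32, -0.05, 0.4, -1.2];
    let (q, scale) = quantize_ternary(&w);
    println!("quantized = {:?}, scale = {}", q, scale);
}
```

At inference time a ternary matmul needs only additions, subtractions, and one multiply by `scale` per output, which is where the memory and compute savings come from.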
## Preset configurations

| Config | Presets |
|--------|---------|
| GPT2Config | tiny, small, medium, large, xl |
| BertConfig | tiny, base, large |
| LLaMAConfig | llama2_7b, llama2_13b, llama3_8b, tiny |
| MistralConfig | mistral_7b, mistral_7b_instruct, mixtral_8x7b, tiny |
| PhiConfig | phi1, phi2, phi3_mini, tiny |
| SSMConfig | from_d_model(d_model, vocab_size) builder |
| HydraConfig | base (~300M), small, tiny |
| ChimeraConfig | default_2b (8 experts, top-2), small, tiny |
| TridentConfig | default_150m, tiny, medium, plus 1B/3B/smoke constructors exposed for training |
## Shared building blocks

- Attention (attention) — MultiHeadSelfAttention, CausalSelfAttention, scaled_dot_product_attention, plus FlashAttention + FlashAttentionConfig. KVCache and per-layer LayerKVCache for incremental decoding.
- Transformer (transformer) — TransformerBlock, TransformerEncoder, TransformerDecoder with configurable depth/width/heads/activation and pre- or post-norm.
- Embeddings (embedding) — TokenEmbedding, PositionalEmbedding (sinusoidal), GPT2Embedding, BertEmbedding.
- Generation (generation) — GenerationConfig with greedy(), sampling(temp), top_k_sampling(k, temp), nucleus_sampling(p, temp), beam_search(beams), plus builder methods with_max_tokens, with_eos_token, with_repetition_penalty. TextGenerator drives the next-token logic.
- Weight loading — HFLoader (safetensors), load_llama_from_hf, load_mistral_from_hf, and generic LoadStateDict trait with LoadResult + map_hf_to_axonml / map_axonml_to_hf name mappers.
- Hub — PretrainedLLM registry (llm_registry(), list_models()), download_weights(name, force) with an on-disk cache under $XDG_CACHE_HOME/axonml/hub/llm.
- Tokenizer — HFTokenizer + SpecialTokens (HuggingFace-compatible tokenizer.json).
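For intuition about the attention primitive, here is a dependency-free sketch of scaled dot-product attention for a single query vector. The crate's scaled_dot_product_attention operates on batched tensors with masking, but the underlying math is the same:

```rust
/// Scaled dot-product attention for one query (illustrative sketch):
/// out = softmax(q · K^T / sqrt(d)) · V
fn attention(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let d = q.len() as f32;
    // Similarity of the query to each key, scaled by sqrt(d).
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / d.sqrt())
        .collect();
    // Numerically stable softmax over the scores.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let weights: Vec<f32> = exps.iter().map(|e| e / sum).collect();
    // Attention output is the weight-averaged value vectors.
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w * x;
        }
    }
    out
}

fn main() {
    let q = vec![1.0, 0.0];
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let values = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    println!("{:?}", attention(&q, &keys, &values));
}
```

The KVCache avoids recomputing the `keys`/`values` arrays for past tokens during incremental decoding: each step appends one new key/value pair and reruns only the new query.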
## Modules

| Module | Description |
|--------|-------------|
| attention | Multi-head / causal / flash attention, KV cache |
| bert | BERT encoder + classification and MLM heads |
| chimera | Sparse MoE + differential attention model |
| config | GPT2Config, BertConfig, TransformerConfig |
| embedding | Token, positional, BERT/GPT-2 combined embeddings |
| error | LLMError / LLMResult |
| generation | GenerationConfig, TextGenerator, sampling strategies |
| gpt2 | GPT-2 + GPT2LMHead |
| hf_loader | HuggingFace safetensors loading (LLaMA, Mistral, …) |
| hub | Pretrained-weight registry and downloader |
| hydra | SSM + windowed-attention hybrid |
| llama | LLaMA with RoPE, GQA, SwiGLU, RMSNorm |
| mistral | Mistral (LLaMA + sliding-window attention) |
| phi | Phi with partial RoPE / GELU / optional parallel-attn |
| ssm | Mamba-style selective SSM blocks and LM head |
| state_dict | LoadStateDict trait, HF ↔ AxonML name mapping |
| tokenizer | HuggingFace-compatible HFTokenizer |
| transformer | Encoder / decoder blocks, feed-forward, layer norm |
| trident | 1.58-bit ternary SLM (BitNet b1.58 recipe) |
## Usage

Add the crate to your Cargo.toml:

```toml
[dependencies]
axonml-llm = "0.6.1"
```
### GPT-2 text generation

```rust
use axonml_llm::{GPT2LMHead, GPT2Config};
use axonml_tensor::Tensor;

let config = GPT2Config::small();
let model = GPT2LMHead::new(&config);

let input_ids = Tensor::from_vec(vec![50256u32, 1, 2, 3], &[1, 4]).unwrap();
let output = model.generate(&input_ids, 50, 0.8, Some(50));
let greedy = model.generate_greedy(&input_ids, 50);
```
### BERT sequence classification

```rust
use axonml_llm::{BertForSequenceClassification, BertConfig};
use axonml_tensor::Tensor;

let config = BertConfig::base();
let model = BertForSequenceClassification::new(&config, 2);
let input_ids = Tensor::from_vec(vec![101u32, 2054, 2003, 1996, 102], &[1, 5]).unwrap();
let logits = model.forward_classification(&input_ids);
```
### BERT masked language modeling

```rust
use axonml_llm::{BertForMaskedLM, BertConfig};
use axonml_tensor::Tensor;

let model = BertForMaskedLM::new(&BertConfig::base());
let input_ids = Tensor::from_vec(vec![101u32, 2054, 103, 1996, 102], &[1, 5]).unwrap();
let logits = model.forward_mlm(&input_ids);
```
### LLaMA / Mistral / Phi

```rust
use axonml_llm::{
    LLaMAForCausalLM, LLaMAConfig, MistralForCausalLM, MistralConfig, PhiForCausalLM, PhiConfig,
};

let llama = LLaMAForCausalLM::new(&LLaMAConfig::llama2_7b());
let mistral = MistralForCausalLM::new(&MistralConfig::mistral_7b());
let phi = PhiForCausalLM::new(&PhiConfig::phi2());
```
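The split-halves RoPE these models use pairs element i with element i + d/2 of each head and rotates the pair by a position-dependent angle. The sketch below shows the math in plain Rust; it is an illustration, not the crate's rotary implementation (`base` is the conventional 10000 frequency base):

```rust
/// Split-halves rotary position embedding (LLaMA-style, illustrative):
/// element i is paired with element i + d/2, and the pair is rotated
/// by the angle pos * base^(-2i/d).
fn apply_rope(x: &mut [f32], pos: usize, base: f32) {
    let d = x.len();
    let half = d / 2;
    for i in 0..half {
        let theta = (pos as f32) * base.powf(-2.0 * i as f32 / d as f32);
        let (sin, cos) = theta.sin_cos();
        let (a, b) = (x[i], x[i + half]);
        x[i] = a * cos - b * sin;
        x[i + half] = a * sin + b * cos;
    }
}

fn main() {
    let mut x = vec![1.0, 0.0, 0.0, 0.0];
    apply_rope(&mut x, 1, 10000.0);
    println!("{:?}", x);
}
```

Because each rotation is norm-preserving and depends only on position, the dot product of a rotated query and key depends only on their relative offset, which is what makes RoPE a relative position encoding.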
### SSM (Mamba-style) causal LM

```rust
use axonml_llm::{SSMForCausalLM, SSMConfig};

let cfg = SSMConfig::from_d_model(512, 32000);
let ssm = SSMForCausalLM::new(&cfg);
```
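In its simplest 1-D form, the selective scan at the heart of Mamba/S6-style blocks is a gated linear recurrence whose decay, input, and output coefficients vary per timestep (that per-timestep dependence is the "selective" part). A dependency-free sketch, not SSMBlock's actual implementation:

```rust
/// Simplified 1-D selective scan (illustrative): a diagonal linear
/// recurrence h_t = a_t * h_{t-1} + b_t * x_t with readout y_t = c_t * h_t,
/// where a, b, c vary per timestep.
fn selective_scan(x: &[f32], a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    let mut h = 0.0_f32;
    x.iter()
        .zip(a)
        .zip(b)
        .zip(c)
        .map(|(((x_t, a_t), b_t), c_t)| {
            h = a_t * h + b_t * x_t; // state update
            c_t * h                  // readout
        })
        .collect()
}

fn main() {
    let x = [1.0, 1.0, 1.0];
    let y = selective_scan(&x, &[0.5, 0.5, 0.5], &[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0]);
    println!("{:?}", y); // → [1.0, 1.5, 1.75]
}
```

Real SSM blocks carry a vector-valued state per channel and fuse this recurrence with the depthwise Conv1d and RMSNorm mentioned in the architecture table, but the sequential dependency is exactly this one.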
### Hydra hybrid SSM + windowed attention

```rust
use axonml_llm::{HydraModel, HydraConfig};

let hydra = HydraModel::new(&HydraConfig::base());
```
### Chimera sparse MoE + differential attention

```rust
use axonml_llm::{ChimeraModel, ChimeraConfig};

let cfg = ChimeraConfig::default_2b();
let model = ChimeraModel::new(&cfg);
```
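Chimera's top-k routing can be sketched standalone: softmax the router logits, keep the k most probable experts, and renormalize their weights so they sum to 1. This illustration omits the load-balancing auxiliary loss and is not the crate's router:

```rust
/// Minimal top-k expert routing (illustrative): softmax over router logits,
/// keep the k largest, renormalize the kept probabilities to sum to 1.
fn top_k_route(logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    // Numerically stable softmax.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Pair each expert index with its probability, sort descending.
    let mut probs: Vec<(usize, f32)> = exps.iter().map(|e| e / sum).enumerate().collect();
    probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    probs.truncate(k);
    // Renormalize over the kept experts.
    let kept: f32 = probs.iter().map(|(_, p)| p).sum();
    probs.iter().map(|&(i, p)| (i, p / kept)).collect()
}

fn main() {
    let routed = top_k_route(&[2.0, 0.5, 1.0, -1.0], 2);
    println!("{:?}", routed);
}
```

Each token's hidden state is then sent only to the selected experts and recombined with these weights, which is why an 8-expert top-2 model activates roughly a quarter of its FFN parameters per token.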
### Trident 1.58-bit ternary LM

```rust
use axonml_llm::{TridentModel, TridentConfig};

let cfg = TridentConfig::default_150m();
let model = TridentModel::new(&cfg);
```
### Custom transformer encoder

```rust
use axonml_llm::TransformerEncoder;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// 6 layers, d_model 512, 8 heads, FFN dim 2048, dropout 0.1,
// layer-norm eps 1e-12, GELU activation, post-norm.
let encoder = TransformerEncoder::new(6, 512, 8, 2048, 0.1, 1e-12, "gelu", false);
let input = Variable::new(Tensor::randn(&[2, 128, 512]), false);
let output = encoder.forward(&input);
```
### Generation configuration

```rust
use axonml_llm::{GenerationConfig, TextGenerator};

let config = GenerationConfig::nucleus_sampling(0.95, 0.8)
    .with_max_tokens(100)
    .with_repetition_penalty(1.2)
    .with_eos_token(50256);

let generator = TextGenerator::new(config);
let next_token = generator.get_next_token(&logits, &generated_so_far);
```
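For intuition about what with_repetition_penalty does, here is a standalone sketch of penalized greedy selection. It follows the common CTRL-style convention (divide positive logits, multiply negative ones); the crate's exact rule may differ:

```rust
/// Greedy next-token selection with a repetition penalty (illustrative):
/// logits of already-generated tokens are divided (if positive) or
/// multiplied (if negative) by the penalty before taking the argmax.
fn next_token(logits: &mut [f32], generated: &[usize], penalty: f32) -> usize {
    for &t in generated {
        if logits[t] > 0.0 {
            logits[t] /= penalty;
        } else {
            logits[t] *= penalty;
        }
    }
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    // Token 0 has the highest raw logit, but it was already generated,
    // so the penalty lets token 1 win.
    let mut logits = vec![2.0, 1.8, 0.5];
    println!("{}", next_token(&mut logits, &[0], 1.5));
}
```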
### Loading HuggingFace weights

```rust
use axonml_llm::{load_llama_from_hf, LLaMAConfig, LLaMAForCausalLM};

let mut model = LLaMAForCausalLM::new(&LLaMAConfig::llama2_7b());
load_llama_from_hf(&mut model, "path/to/model.safetensors")?;
```
### Pretrained hub

```rust
use axonml_llm::{llm_registry, download_llm_weights};

for m in llm_registry().values() {
    println!("{} ({} params)", m.name, m.num_parameters);
}

let path = download_llm_weights("bert-base-uncased", false)?;
```
## Configuration reference

### BERT

| Config | Hidden size | Layers | Heads |
|--------|-------------|--------|-------|
| BertConfig::tiny() | 128 | 2 | 2 |
| BertConfig::base() | 768 | 12 | 12 |
| BertConfig::large() | 1024 | 24 | 16 |

### GPT-2

| Config | Embedding dim | Layers | Heads |
|--------|---------------|--------|-------|
| GPT2Config::tiny() | 128 | 2 | 2 |
| GPT2Config::small() | 768 | 12 | 12 |
| GPT2Config::medium() | 1024 | 24 | 16 |
| GPT2Config::large() | 1280 | 36 | 20 |
| GPT2Config::xl() | 1600 | 48 | 25 |
### LLaMA / Mistral / Phi

| Config | Presets |
|--------|---------|
| LLaMA | llama2_7b, llama2_13b, llama3_8b, tiny |
| Mistral | mistral_7b, mistral_7b_instruct, mixtral_8x7b, tiny |
| Phi | phi1, phi2, phi3_mini, tiny |
### Hybrid / SSM / MoE / Ternary

| Config | Presets / builder |
|--------|-------------------|
| SSMConfig | from_d_model(d_model, vocab) |
| HydraConfig | base, small, tiny |
| ChimeraConfig | default_2b, small, tiny |
| TridentConfig | default_150m, tiny, medium (plus 1B/3B/smoke constructors) |
## Generation strategies

| Strategy | Method | Description |
|----------|--------|-------------|
| Greedy | GenerationConfig::greedy() | Always takes the argmax |
| Sampling | GenerationConfig::sampling(temp) | Temperature-scaled softmax sampling |
| Top-K | GenerationConfig::top_k_sampling(k, temp) | Sample from the top-k tokens |
| Nucleus | GenerationConfig::nucleus_sampling(p, temp) | Sample from the top-p probability mass |
| Beam Search | GenerationConfig::beam_search(beams) | Beam-search decoding |
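Of these, nucleus sampling is the least obvious: keep the smallest set of tokens whose cumulative probability mass reaches p, renormalize, then sample within that set. The standalone sketch below illustrates the idea (not the crate's implementation); the uniform draw `u` is passed in explicitly so the sketch stays deterministic and dependency-free:

```rust
/// Nucleus (top-p) sampling (illustrative): sort probabilities descending,
/// keep the smallest prefix whose mass reaches `p`, renormalize, then
/// sample by inverse-CDF using a uniform draw `u` in [0, 1).
fn nucleus_sample(probs: &[f32], p: f32, u: f32) -> usize {
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
    // Keep tokens until the cumulative mass reaches p.
    let mut kept = Vec::new();
    let mut mass = 0.0;
    for &i in &order {
        kept.push(i);
        mass += probs[i];
        if mass >= p {
            break;
        }
    }
    // Inverse-CDF sample over the renormalized nucleus.
    let mut cum = 0.0;
    for &i in &kept {
        cum += probs[i] / mass;
        if u < cum {
            return i;
        }
    }
    *kept.last().unwrap()
}

fn main() {
    let probs = [0.5, 0.3, 0.15, 0.05];
    println!("{}", nucleus_sample(&probs, 0.7, 0.9));
}
```

Unlike top-k, the nucleus adapts its size to the shape of the distribution: a peaked distribution keeps one or two tokens, a flat one keeps many.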
## Tests

```sh
cargo test -p axonml-llm
cargo test -p axonml-llm -- --nocapture
```
## License

Licensed under either of:

- MIT License
- Apache License, Version 2.0

at your option.
Last updated: 2026-04-16 (v0.6.1)