Skip to main content

axonml_llm/
lib.rs

1//! axonml-llm - Large Language Model Architectures
2//!
3//! This crate provides implementations of popular transformer-based language models
4//! including BERT, GPT-2, LLaMA, Mistral, and Phi, along with building blocks for
5//! custom LLM architectures.
6//!
7//! # Key Features
8//! - BERT (Bidirectional Encoder Representations from Transformers)
9//! - GPT-2 (Generative Pre-trained Transformer 2)
10//! - LLaMA (Large Language Model Meta AI) with RoPE and SwiGLU
11//! - Mistral with sliding window attention
12//! - Phi with partial rotary embeddings
13//! - KV-cache for efficient autoregressive generation
14//! - Transformer building blocks (attention, feed-forward, positional encoding)
15//! - Text generation utilities
16//!
17//! # Example
18//! ```ignore
19//! use axonml_llm::{GPT2, GPT2Config};
20//! use axonml_tensor::Tensor;
21//!
22//! // Create a GPT-2 model
23//! let config = GPT2Config::small();
24//! let model = GPT2::new(&config);
25//!
26//! // Run a single forward pass (use `TextGenerator` for autoregressive generation)
27//! let input_ids = Tensor::from_vec(vec![50256u32], &[1, 1]).unwrap();
28//! let output = model.forward(&input_ids);
29//! ```
30//!
31//! @version 0.2.0
32//! @author AutomataNexus Development Team
33
34#![warn(missing_docs)]
35#![warn(clippy::all)]
36
37pub mod error;
38pub mod config;
39pub mod attention;
40pub mod embedding;
41pub mod hub;
42pub mod hf_loader;
43pub mod tokenizer;
44pub mod state_dict;
45pub mod transformer;
46pub mod bert;
47pub mod gpt2;
48pub mod llama;
49pub mod mistral;
50pub mod phi;
51pub mod generation;
52
53pub use error::{LLMError, LLMResult};
54pub use config::{BertConfig, GPT2Config, TransformerConfig};
55pub use attention::{
56    CausalSelfAttention, FlashAttention, FlashAttentionConfig, KVCache, LayerKVCache,
57    MultiHeadSelfAttention, scaled_dot_product_attention,
58};
59pub use embedding::{TokenEmbedding, PositionalEmbedding, BertEmbedding, GPT2Embedding};
60pub use hub::{PretrainedLLM, llm_registry, download_weights as download_llm_weights};
61pub use hf_loader::{HFLoader, load_llama_from_hf, load_mistral_from_hf};
62pub use tokenizer::{HFTokenizer, SpecialTokens};
63pub use state_dict::{LoadStateDict, LoadResult};
64pub use transformer::{TransformerBlock, TransformerEncoder, TransformerDecoder};
65pub use bert::{Bert, BertForSequenceClassification, BertForMaskedLM};
66pub use gpt2::{GPT2, GPT2LMHead};
67pub use llama::{LLaMA, LLaMAConfig, LLaMAForCausalLM};
68pub use mistral::{Mistral, MistralConfig, MistralForCausalLM};
69pub use phi::{Phi, PhiConfig, PhiForCausalLM};
70pub use generation::{GenerationConfig, TextGenerator};
71
72// =============================================================================
73// Tests
74// =============================================================================
75
#[cfg(test)]
mod tests {
    use super::*;

    /// The GPT-2 `small` preset: 12 layers, 12 attention heads, 768-dim embeddings.
    #[test]
    fn test_gpt2_config() {
        let cfg = GPT2Config::small();
        let (layers, heads, width) = (cfg.n_layer, cfg.n_head, cfg.n_embd);
        assert_eq!(layers, 12);
        assert_eq!(heads, 12);
        assert_eq!(width, 768);
    }

    /// The BERT `base` preset: 12 hidden layers, 12 attention heads, hidden size 768.
    #[test]
    fn test_bert_config() {
        let cfg = BertConfig::base();
        let (layers, heads, width) = (
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            cfg.hidden_size,
        );
        assert_eq!(layers, 12);
        assert_eq!(heads, 12);
        assert_eq!(width, 768);
    }
}