#![warn(missing_docs)]
#![warn(clippy::all)]
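//! Transformer language models in Rust: model configurations, attention and
//! embedding layers, tokenization, Hugging Face weight loading, and text
//! generation for BERT, GPT-2, LLaMA, Mistral, and Phi.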

/// Error and result types shared across the crate.
pub mod error;
/// Configuration structs for the supported model families.
pub mod config;
/// Attention primitives: scaled dot-product, multi-head, causal, and flash attention with KV caching.
pub mod attention;
/// Token and positional embedding layers.
pub mod embedding;
/// Pretrained-model registry and weight downloading.
pub mod hub;
/// Loaders for Hugging Face model checkpoints.
pub mod hf_loader;
/// Hugging Face tokenizer wrapper and special-token handling.
pub mod tokenizer;
/// Loading state dicts into model weights.
pub mod state_dict;
/// Transformer blocks, encoder, and decoder stacks.
pub mod transformer;
/// BERT and its sequence-classification and masked-LM heads.
pub mod bert;
/// GPT-2 and its language-modeling head.
pub mod gpt2;
/// LLaMA models for causal language modeling.
pub mod llama;
/// Mistral models for causal language modeling.
pub mod mistral;
/// Phi models for causal language modeling.
pub mod phi;
/// Text generation configuration and utilities.
pub mod generation;

pub use error::{LLMError, LLMResult};
pub use config::{BertConfig, GPT2Config, TransformerConfig};
pub use attention::{
    CausalSelfAttention, FlashAttention, FlashAttentionConfig, KVCache, LayerKVCache,
    MultiHeadSelfAttention, scaled_dot_product_attention,
};
pub use embedding::{TokenEmbedding, PositionalEmbedding, BertEmbedding, GPT2Embedding};
pub use hub::{PretrainedLLM, llm_registry, download_weights as download_llm_weights};
pub use hf_loader::{HFLoader, load_llama_from_hf, load_mistral_from_hf};
pub use tokenizer::{HFTokenizer, SpecialTokens};
pub use state_dict::{LoadStateDict, LoadResult};
pub use transformer::{TransformerBlock, TransformerEncoder, TransformerDecoder};
pub use bert::{Bert, BertForSequenceClassification, BertForMaskedLM};
pub use gpt2::{GPT2, GPT2LMHead};
pub use llama::{LLaMA, LLaMAConfig, LLaMAForCausalLM};
pub use mistral::{Mistral, MistralConfig, MistralForCausalLM};
pub use phi::{Phi, PhiConfig, PhiForCausalLM};
pub use generation::{GenerationConfig, TextGenerator};

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpt2_config() {
        let config = GPT2Config::small();
        assert_eq!(config.n_layer, 12);
        assert_eq!(config.n_head, 12);
        assert_eq!(config.n_embd, 768);
    }

    #[test]
    fn test_bert_config() {
        let config = BertConfig::base();
        assert_eq!(config.num_hidden_layers, 12);
        assert_eq!(config.num_attention_heads, 12);
        assert_eq!(config.hidden_size, 768);
    }
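
    // A small added consistency check, not part of the original suite: it
    // assumes only the config fields asserted above and verifies that the
    // embedding width divides evenly among the attention heads, which
    // multi-head attention requires.
    #[test]
    fn test_hidden_size_divisible_by_heads() {
        let gpt2 = GPT2Config::small();
        assert_eq!(gpt2.n_embd % gpt2.n_head, 0);

        let bert = BertConfig::base();
        assert_eq!(bert.hidden_size % bert.num_attention_heads, 0);
    }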
}