Skip to main content

alith_models/local_model/gguf/preset/
mod.rs

//! Checklist for adding a preset model:
//! 1. Add a new variant to the `LLMPreset` enum via the `generate_models!` macro
//! 2. Create a directory for the new model in `llm_client/llm_models/src/local_model/gguf/preset`
//! 3. Add `model_macro_data.json` to the new model's directory
//! 4. Add the model's `config.json` to the new model's directory
//! 5. (Optional) Add the model's `tokenizer_config.json` to the new model's directory
//! 6. (Optional) Add the model's `tokenizer.json` to the new model's directory
//! 7. Add a test to `llm_client/llm_models/tests/it/preset.rs` for the new model
//! 8. Add a `test_base_generation_prefix` test case to `llm_client/llm_models/tests/it/metadata.rs` for the new model
10use crate::local_model::{
11    GgufLoader, LocalLLMModel, gguf::loaders::preset::GgufPresetLoader,
12    hf_loader::HuggingFaceLoader, metadata::config_json::ConfigJson,
13};
14
15fn presets_path() -> std::path::PathBuf {
16    let cargo_manifest_dir = env!("CARGO_MANIFEST_DIR");
17    std::path::PathBuf::from(cargo_manifest_dir)
18        .join("src")
19        .join("local_model")
20        .join("gguf")
21        .join("preset")
22}
23
24#[derive(Debug, Clone, serde::Deserialize)]
25pub struct LLMPresetData {
26    pub model_id: String,
27    pub gguf_repo_id: String,
28    pub number_of_parameters: u64,
29    pub f_name_for_q_bits: QuantizationConfig,
30    pub tokenizer_preset_data: TokenizerPresetData,
31    pub tokenizer_config_preset_data: TokenizerConfigPresetData,
32}
33
34#[derive(Debug, Clone, serde::Deserialize)]
35pub struct TokenizerPresetData {
36    pub local_path: Option<String>,
37    pub hf_repo: Option<String>,
38    pub hf_filename: Option<String>,
39}
40impl TokenizerPresetData {
41    pub fn load(&self, hf_loader: &HuggingFaceLoader) -> crate::Result<std::path::PathBuf> {
42        if let Some(local_path) = self.local_path.clone() {
43            let path = presets_path().join(local_path);
44            match std::fs::File::open(&path) {
45                Ok(_) => Ok(path),
46                Err(_) => crate::bail!("Failed to open tokenizer.json at {}", path.display()),
47            }
48        } else if let (Some(hf_repo), Some(hf_filename)) =
49            (self.hf_repo.clone(), self.hf_filename.clone())
50        {
51            hf_loader.load_file(hf_filename, hf_repo)
52        } else {
53            crate::bail!("No local tokenizer.json, or hf_repo and hf_filename provided")
54        }
55    }
56}
57#[derive(Debug, Clone, serde::Deserialize)]
58pub struct TokenizerConfigPresetData {
59    pub local_path: Option<String>,
60    pub hf_repo: Option<String>,
61    pub hf_filename: Option<String>,
62}
63impl TokenizerConfigPresetData {
64    pub fn load(&self, hf_loader: &HuggingFaceLoader) -> crate::Result<std::path::PathBuf> {
65        if let Some(local_path) = self.local_path.clone() {
66            let path = presets_path().join(local_path);
67            match std::fs::File::open(&path) {
68                Ok(_) => Ok(path),
69                Err(_) => {
70                    crate::bail!("Failed to open tokenizer_config.json at {}", path.display())
71                }
72            }
73        } else if let (Some(hf_repo), Some(hf_filename)) =
74            (self.hf_repo.clone(), self.hf_filename.clone())
75        {
76            hf_loader.load_file(hf_filename, hf_repo)
77        } else {
78            crate::bail!("No local tokenizer_config.json, or hf_repo and hf_filename provided")
79        }
80    }
81}
82
83impl LLMPresetData {
84    pub fn new<P: AsRef<std::path::Path>>(path: P) -> LLMPresetData {
85        let cargo_manifest_dir = env!("CARGO_MANIFEST_DIR");
86        let path = std::path::PathBuf::from(cargo_manifest_dir)
87            .join("src")
88            .join("local_model")
89            .join("gguf")
90            .join("preset")
91            .join(path)
92            .join("model_macro_data.json");
93        let mut file = std::fs::File::open(&path)
94            .unwrap_or_else(|_| panic!("Failed to open file at {}", path.display()));
95        let mut contents = String::new();
96        std::io::Read::read_to_string(&mut file, &mut contents).expect("Failed to read file");
97        serde_json::from_str(&contents).expect("Failed to parse JSON")
98    }
99}
100
101#[derive(Debug, Clone, serde::Deserialize)]
102pub struct QuantizationConfig {
103    pub q8: Option<String>,
104    pub q7: Option<String>,
105    pub q6: Option<String>,
106    pub q5: Option<String>,
107    pub q4: Option<String>,
108    pub q3: Option<String>,
109    pub q2: Option<String>,
110    pub q1: Option<String>,
111}
112
// Generates, from a list of `Variant => "relative/preset/dir"` pairs:
//   * the preset enum with one unit variant per pair,
//   * accessors on that enum backed by each preset's `model_macro_data.json`,
//   * the `GgufPresetTrait` builder trait with one snake_case selector method
//     per variant (built with `paste`).
macro_rules! generate_models {
    (
        $enum_name:ident {
            $( $variant:ident => $path:expr ),* $(,)?
        }
    ) => {
        /// One variant per bundled preset model.
        #[derive(Debug, Clone)]
        pub enum $enum_name {
            $( $variant ),*
        }

        impl $enum_name {
            /// Returns the preset's data, loading it on first access.
            ///
            /// Each variant owns its own lazily initialized static, so a preset's
            /// `model_macro_data.json` is read at most once per process.
            /// `LLMPresetData::new` panics if that file cannot be loaded.
            pub fn get_data(&self) -> &'static LLMPresetData {
                match self {
                    $(
                        Self::$variant => {
                            static DATA: std::sync::LazyLock<LLMPresetData> =
                                std::sync::LazyLock::new(|| LLMPresetData::new($path));
                            &DATA
                        }
                    ),*
                }
            }

            /// Returns an owned copy of the preset's `model_id`.
            pub fn model_id(&self) -> String {
                self.get_data().model_id.clone()
            }

            /// Returns the preset's `gguf_repo_id`.
            pub fn gguf_repo_id(&self) -> &str {
                &self.get_data().gguf_repo_id
            }

            /// Parses the `config.json` stored in the preset's directory.
            pub fn config_json(&self) -> crate::Result<ConfigJson> {
                let config_path = self.config_json_path();
                ConfigJson::from_local_path(&config_path)
            }

            /// Returns the name registered for `q_bits`, if the preset has one.
            ///
            /// # Panics
            ///
            /// Panics when `q_bits` is outside `1..=8`.
            pub fn f_name_for_q_bits(&self, q_bits: u8) -> Option<String> {
                // Looked up lazily so that an out-of-range `q_bits` panics
                // before the preset data is touched.
                let names = || &self.get_data().f_name_for_q_bits;
                match q_bits {
                    1 => names().q1.clone(),
                    2 => names().q2.clone(),
                    3 => names().q3.clone(),
                    4 => names().q4.clone(),
                    5 => names().q5.clone(),
                    6 => names().q6.clone(),
                    7 => names().q7.clone(),
                    8 => names().q8.clone(),
                    _ => panic!("Quantization bits must be between 1 and 8"),
                }
            }

            /// Returns the parameter count: the stored value (in billions)
            /// multiplied by `1_000_000_000`.
            pub fn number_of_parameters(&self) -> f64 {
                let billions = self.get_data().number_of_parameters as f64;
                billions * 1_000_000_000.0
            }

            /// Directory of this preset inside the presets directory.
            fn preset_dir_path(&self) -> std::path::PathBuf {
                match self {
                    $( Self::$variant => presets_path().join($path) ),*
                }
            }

            /// Path of the preset's `config.json`.
            pub fn config_json_path(&self) -> std::path::PathBuf {
                self.preset_dir_path().join("config.json")
            }

            /// Resolves the preset's tokenizer file via `TokenizerPresetData::load`.
            pub fn load_tokenizer(
                &self,
                hf_loader: &HuggingFaceLoader,
            ) -> crate::Result<std::path::PathBuf> {
                let source = &self.get_data().tokenizer_preset_data;
                source.load(hf_loader)
            }

            /// Resolves the preset's tokenizer config file via
            /// `TokenizerConfigPresetData::load`.
            pub fn load_tokenizer_config(
                &self,
                hf_loader: &HuggingFaceLoader,
            ) -> crate::Result<std::path::PathBuf> {
                let source = &self.get_data().tokenizer_config_preset_data;
                source.load(hf_loader)
            }

            /// Loads this preset through a default `GgufLoader`.
            pub fn load(&self) -> crate::Result<LocalLLMModel> {
                let mut gguf_loader = GgufLoader::default();
                gguf_loader.gguf_preset_loader.llm_preset = self.clone();
                gguf_loader.load()
            }
        }

        /// Builder-style setters for anything that owns a `GgufPresetLoader`.
        pub trait GgufPresetTrait {
            /// Gives mutable access to the underlying preset loader.
            fn preset_loader(&mut self) -> &mut GgufPresetLoader;

            /// Stores the given amount on the preset loader and returns `self`.
            /// NOTE(review): presumably the VRAM budget in GB used when picking
            /// a quantization level — confirm in `GgufPresetLoader`.
            fn preset_with_available_vram_gb(mut self, preset_with_available_vram_gb: u32) -> Self
            where
                Self: Sized,
            {
                let preset_loader = self.preset_loader();
                preset_loader.preset_with_available_vram_gb = Some(preset_with_available_vram_gb);
                self
            }

            /// Stores the requested quantization level on the preset loader and
            /// returns `self`.
            fn preset_with_quantization_level(mut self, level: u8) -> Self
            where
                Self: Sized,
            {
                let preset_loader = self.preset_loader();
                preset_loader.preset_with_quantization_level = Some(level);
                self
            }

            // One selector per variant, named after the variant in snake_case.
            $(
                paste::paste! {
                    fn [<$variant:snake>](mut self) -> Self
                    where
                        Self: Sized,
                    {
                        let preset_loader = self.preset_loader();
                        preset_loader.llm_preset = $enum_name::$variant;
                        self
                    }
                }
            )*
        }
    };
}
239
240generate_models!(
241    LLMPreset {
242        SuperNovaMedius13b => "arcee/supernova_medius",
243        Llama3_1_8bInstruct => "llama/llama3_1_8b_instruct",
244        Llama3_2_3bInstruct => "llama/llama3_2_3b_instruct",
245        Llama3_2_1bInstruct => "llama/llama3_2_1b_instruct",
246        Mistral7bInstructV0_3 => "mistral/mistral7b_instruct_v0_3",
247        Mixtral8x7bInstructV0_1 => "mistral/mixtral8x7b_instruct_v0_1",
248        MistralNemoInstruct2407 => "mistral/mistral_nemo_instruct_2407",
249        MistralSmallInstruct2409 => "mistral/mistral_small_instruct_2409",
250        Phi3Medium4kInstruct => "phi/phi3_medium4k_instruct",
251        Phi3Mini4kInstruct => "phi/phi3_mini4k_instruct",
252        Phi3_5MiniInstruct => "phi/phi3_5_mini_instruct",
253        Granite3_8bInstruct => "granite/granite3_8b_instruct",
254        Granite3_2bInstruct => "granite/granite3_2b_instruct",
255        Qwen2_5_32bInstruct => "qwen/qwen2_5_32b_instruct",
256        Qwen2_5_14bInstruct => "qwen/qwen2_5_14b_instruct",
257        Qwen2_5_7bInstruct => "qwen/qwen2_5_7b_instruct",
258        Qwen2_5_3bInstruct => "qwen/qwen2_5_3b_instruct",
259        Llama3_1_70bNemotronInstruct => "nvidia/llama3_1_70b_nemotron_instruct",
260        MistralNemoMinitron8bInstruct => "nvidia/mistral_nemo_minitron_8b_instruct",
261        StableLm2_12bChat => "stabilityai/stablelm_2_12b_chat",
262    }
263);