1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod embedding_loaders;
4mod multimodal_loaders;
5mod normal_loaders;
6pub use auto_device_map::AutoDeviceMapParams;
7use auto_device_map::NonMappedSubModel;
8
9use std::{
10 fmt::{self, Debug},
11 path::PathBuf,
12 str::FromStr,
13 sync::Arc,
14};
15
16use anyhow::Result;
17use as_any::AsAny;
18use hanzo_ml::{DType, Device};
19use hanzo_quant::{IsqType, QuantizedConfig};
20use serde::Deserialize;
21use tokio::sync::Mutex;
22
23pub use normal_loaders::{
24 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
25 GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
26 MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
27 NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
28 Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
29};
30
31pub use multimodal_loaders::{
32 AutoMultimodalLoader, Gemma3Loader, Gemma3nLoader, Gemma4Loader, Idefics2Loader,
33 Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader, Mistral3Loader,
34 MultimodalLoaderType, MultimodalModel, MultimodalModelLoader, Phi3VLoader, Phi4MMLoader,
35 Qwen2VLLoader, Qwen2_5VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, Qwen3_5Loader,
36 Qwen3_5MoeLoader, VLlama4Loader, VLlamaLoader, VoxtralLoader,
37};
38
39pub use embedding_loaders::{
40 AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
41 EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
42 Qwen3EmbeddingLoader,
43};
44
45pub use diffusion_loaders::{
46 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
47 DiffusionModelPathsInner, FluxLoader,
48};
49
50use crate::{
51 matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
52 DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
53};
54
55use super::{paths::AdapterPaths, Pipeline};
56
57pub trait ModelPaths: AsAny + Debug + Send + Sync {
60 fn get_weight_filenames(&self) -> &[PathBuf];
62
63 fn get_config_filename(&self) -> &PathBuf;
67
68 fn get_tokenizer_filename(&self) -> &PathBuf;
72
73 fn get_template_filename(&self) -> &Option<PathBuf>;
77
78 fn get_gen_conf_filename(&self) -> Option<&PathBuf>;
80
81 fn get_preprocessor_config(&self) -> &Option<PathBuf>;
83
84 fn get_processor_config(&self) -> &Option<PathBuf>;
86
87 fn get_chat_template_explicit(&self) -> &Option<PathBuf>;
89
90 fn get_adapter_paths(&self) -> &AdapterPaths;
92
93 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
95}
96
97#[derive(Clone, Debug)]
98pub struct LocalModelPaths<P: Debug> {
100 pub tokenizer_filename: P,
101 pub config_filename: P,
102 pub template_filename: Option<P>,
103 pub filenames: Vec<P>,
104 pub adapter_paths: AdapterPaths,
105 pub gen_conf: Option<P>,
106 pub preprocessor_config: Option<P>,
107 pub processor_config: Option<P>,
108 pub chat_template_json_filename: Option<P>,
109}
110
111impl<P: Debug> LocalModelPaths<P> {
112 #[allow(clippy::too_many_arguments)]
113 pub fn new(
114 tokenizer_filename: P,
115 config_filename: P,
116 template_filename: P,
117 filenames: Vec<P>,
118 adapter_paths: AdapterPaths,
119 gen_conf: Option<P>,
120 preprocessor_config: Option<P>,
121 processor_config: Option<P>,
122 chat_template_json_filename: Option<P>,
123 ) -> Self {
124 Self {
125 tokenizer_filename,
126 config_filename,
127 template_filename: Some(template_filename),
128 filenames,
129 adapter_paths,
130 gen_conf,
131 preprocessor_config,
132 processor_config,
133 chat_template_json_filename,
134 }
135 }
136}
137
138impl ModelPaths for LocalModelPaths<PathBuf> {
139 fn get_config_filename(&self) -> &PathBuf {
140 &self.config_filename
141 }
142 fn get_tokenizer_filename(&self) -> &PathBuf {
143 &self.tokenizer_filename
144 }
145 fn get_weight_filenames(&self) -> &[PathBuf] {
146 &self.filenames
147 }
148 fn get_template_filename(&self) -> &Option<PathBuf> {
149 &self.template_filename
150 }
151 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
152 self.gen_conf.as_ref()
153 }
154 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
155 &self.preprocessor_config
156 }
157 fn get_processor_config(&self) -> &Option<PathBuf> {
158 &self.processor_config
159 }
160 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
161 &self.chat_template_json_filename
162 }
163 fn get_adapter_paths(&self) -> &AdapterPaths {
164 &self.adapter_paths
165 }
166 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
167 None
168 }
169}
170
171#[derive(Clone, Debug)]
172pub struct EmbeddingModelPaths<P: Debug> {
174 pub tokenizer_filename: P,
175 pub config_filename: P,
176 pub modules: Vec<EmbeddingModulePaths>,
177 pub filenames: Vec<P>,
178 pub adapter_paths: AdapterPaths,
179}
180
181impl<P: Debug> EmbeddingModelPaths<P> {
182 #[allow(clippy::too_many_arguments)]
183 pub fn new(
184 tokenizer_filename: P,
185 config_filename: P,
186 filenames: Vec<P>,
187 adapter_paths: AdapterPaths,
188 modules: Vec<EmbeddingModulePaths>,
189 ) -> Self {
190 Self {
191 tokenizer_filename,
192 config_filename,
193 filenames,
194 adapter_paths,
195 modules,
196 }
197 }
198}
199
200impl ModelPaths for EmbeddingModelPaths<PathBuf> {
201 fn get_config_filename(&self) -> &PathBuf {
202 &self.config_filename
203 }
204 fn get_tokenizer_filename(&self) -> &PathBuf {
205 &self.tokenizer_filename
206 }
207 fn get_weight_filenames(&self) -> &[PathBuf] {
208 &self.filenames
209 }
210 fn get_template_filename(&self) -> &Option<PathBuf> {
211 &None
212 }
213 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
214 None
215 }
216 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
217 &None
218 }
219 fn get_processor_config(&self) -> &Option<PathBuf> {
220 &None
221 }
222 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
223 &None
224 }
225 fn get_adapter_paths(&self) -> &AdapterPaths {
226 &self.adapter_paths
227 }
228 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
229 Some(&self.modules)
230 }
231}
232
233#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
234pub enum TokenSource {
236 Literal(String),
237 EnvVar(String),
238 Path(String),
239 CacheToken,
240 None,
241}
242
243impl FromStr for TokenSource {
244 type Err = String;
245
246 fn from_str(s: &str) -> Result<Self, Self::Err> {
247 let parts: Vec<&str> = s.splitn(2, ':').collect();
248 match parts[0] {
249 "literal" => parts
250 .get(1)
251 .map(|&value| TokenSource::Literal(value.to_string()))
252 .ok_or_else(|| "Expected a value for 'literal'".to_string()),
253 "env" => Ok(TokenSource::EnvVar(
254 parts
255 .get(1)
256 .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
257 .to_string(),
258 )),
259 "path" => parts
260 .get(1)
261 .map(|&value| TokenSource::Path(value.to_string()))
262 .ok_or_else(|| "Expected a value for 'path'".to_string()),
263 "cache" => Ok(TokenSource::CacheToken),
264 "none" => Ok(TokenSource::None),
265 _ => Err("Invalid token source format".to_string()),
266 }
267 }
268}
269
270impl fmt::Display for TokenSource {
271 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272 match self {
273 TokenSource::Literal(value) => write!(f, "literal:{value}"),
274 TokenSource::EnvVar(value) => write!(f, "env:{value}"),
275 TokenSource::Path(value) => write!(f, "path:{value}"),
276 TokenSource::CacheToken => write!(f, "cache"),
277 TokenSource::None => write!(f, "none"),
278 }
279 }
280}
281
282#[derive(Clone, Default, derive_more::From, strum::Display)]
284pub enum ModelKind {
285 #[default]
286 #[strum(to_string = "normal (no adapters)")]
287 Normal,
288
289 #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
290 GgufQuantized { quant: QuantizationKind },
291
292 #[strum(to_string = "{adapter}")]
293 Adapter { adapter: AdapterKind },
294
295 #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
296 GgufAdapter {
297 adapter: AdapterKind,
298 quant: QuantizationKind,
299 },
300
301 #[strum(to_string = "anymoe: target: `{target}`")]
302 AnyMoe { target: Box<ModelKind> },
303}
304
305#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
306#[strum(serialize_all = "kebab-case")]
307pub enum QuantizationKind {
308 Ggml,
310 Gguf,
312 Gptq,
314}
315
316#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
317#[strum(serialize_all = "kebab-case")]
318pub enum AdapterKind {
319 Lora,
321 XLora,
323}
324
325pub trait PrettyName: strum::EnumMessage + ToString {
327 fn pretty_name(&self) -> String {
328 match self.get_documentation() {
329 Some(s) => s.to_string(),
330 None => self.to_string(),
333 }
334 }
335}
336
337impl PrettyName for AdapterKind {}
338impl PrettyName for QuantizationKind {}
339
340impl ModelKind {
341 pub fn is_quantized(&self) -> bool {
343 self.quantized_kind().iter().any(|q| q.is_some())
344 }
345
346 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
347 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
348 }
349
350 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
351 use ModelKind::*;
352
353 match self {
354 Normal | Adapter { .. } => vec![None],
355 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
356 AnyMoe { target } => target.quantized_kind(),
357 }
358 }
359
360 pub fn is_adapted(&self) -> bool {
362 self.adapted_kind().iter().any(|a| a.is_some())
363 }
364
365 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
366 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
367 }
368
369 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
370 use ModelKind::*;
371
372 match self {
373 Normal | GgufQuantized { .. } => vec![None],
374 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
375 AnyMoe { target } => target.adapted_kind(),
376 }
377 }
378}
379
380#[derive(Deserialize)]
381pub struct QuantizationConfigShim {
382 quantization_config: Option<QuantizedConfig>,
383}
384
385impl QuantizationConfigShim {
386 pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
387 let QuantizationConfigShim {
388 quantization_config,
389 } = serde_json::from_str(config)?;
390
391 if let Some(quantization_config) = quantization_config {
392 Ok(quantization_config.pack_factor(dtype))
393 } else {
394 Ok(1)
395 }
396 }
397}
398
399pub trait DeviceMappedModelLoader {
400 fn non_mapped_max_act_size_elems(
403 &self,
404 config: &str,
405 params: &AutoDeviceMapParams,
406 ) -> Result<usize>;
407 fn mapped_max_act_size_elems(
409 &self,
410 config: &str,
411 params: &AutoDeviceMapParams,
412 ) -> Result<usize>;
413 fn non_mapped_size_in_bytes(
415 &self,
416 config: &str,
417 dtype: DType,
418 weight_pack_factor: usize,
419 matformer_config: Option<&MatformerSliceConfig>,
420 ) -> Result<usize>;
421 fn layer_sizes_in_bytes(
423 &self,
424 config: &str,
425 dtype: DType,
426 weight_pack_factor: usize,
427 matformer_config: Option<&MatformerSliceConfig>,
428 ) -> Result<Vec<usize>>;
429 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
430 None
431 }
432 fn num_layers(&self, config: &str) -> Result<usize>;
433 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;
434
435 #[allow(clippy::too_many_arguments)]
436 fn get_device_layers(
437 &self,
438 config: &str,
439 num_layers: usize,
440 layer_sizes_in_bytes: Vec<usize>,
441 non_mapped_size_in_bytes: usize,
442 total_model_size_in_bytes: usize,
443 devices: &[Device],
444 dtype: DType,
445 params: &AutoDeviceMapParams,
446 paged_attn_config: Option<&PagedAttentionConfig>,
447 ) -> Result<DeviceMapMetadata>
448 where
449 Self: Sized,
450 {
451 auto_device_map::get_device_layers(
452 self,
453 config,
454 num_layers,
455 layer_sizes_in_bytes,
456 non_mapped_size_in_bytes,
457 total_model_size_in_bytes,
458 devices,
459 dtype,
460 params,
461 paged_attn_config,
462 )
463 }
464}
465
466pub trait Loader: Send + Sync {
487 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
491 fn load_model_from_hf(
492 &self,
493 revision: Option<String>,
494 token_source: TokenSource,
495 dtype: &dyn TryIntoDType,
496 device: &Device,
497 silent: bool,
498 mapper: DeviceMapSetting,
499 in_situ_quant: Option<IsqType>,
500 paged_attn_config: Option<PagedAttentionConfig>,
501 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
502
503 #[allow(
506 clippy::type_complexity,
507 clippy::too_many_arguments,
508 clippy::borrowed_box
509 )]
510 fn load_model_from_path(
511 &self,
512 paths: &Box<dyn ModelPaths>,
513 dtype: &dyn TryIntoDType,
514 device: &Device,
515 silent: bool,
516 mapper: DeviceMapSetting,
517 in_situ_quant: Option<IsqType>,
518 paged_attn_config: Option<PagedAttentionConfig>,
519 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
520
521 fn get_id(&self) -> String;
522 fn get_kind(&self) -> ModelKind;
523}