1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod embedding_loaders;
4mod multimodal_loaders;
5mod normal_loaders;
6pub use auto_device_map::AutoDeviceMapParams;
7use auto_device_map::NonMappedSubModel;
8
9use std::{
10 fmt::{self, Debug},
11 path::PathBuf,
12 str::FromStr,
13 sync::Arc,
14};
15
16use anyhow::Result;
17use as_any::AsAny;
18use candle_core::{DType, Device};
19use mistralrs_quant::{IsqType, QuantizedConfig};
20use serde::Deserialize;
21use tokio::sync::Mutex;
22
23pub use normal_loaders::{
24 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
25 GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
26 MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
27 NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
28 Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
29};
30
31pub use multimodal_loaders::{
32 AutoMultimodalLoader, Gemma3Loader, Gemma3nLoader, Gemma4Loader, Idefics2Loader,
33 Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader, Mistral3Loader,
34 MultimodalLoaderType, MultimodalModel, MultimodalModelLoader, Phi3VLoader, Phi4MMLoader,
35 Qwen2VLLoader, Qwen2_5VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, Qwen3_5Loader,
36 Qwen3_5MoeLoader, VLlama4Loader, VLlamaLoader, VoxtralLoader,
37};
38
39pub use embedding_loaders::{
40 AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
41 EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
42 Qwen3EmbeddingLoader,
43};
44
45pub use diffusion_loaders::{
46 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
47 DiffusionModelPathsInner, FluxLoader,
48};
49
50use crate::{
51 matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
52 DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
53};
54
55use super::{paths::AdapterPaths, Pipeline};
56
/// Read-only view of the files that make up a model on disk.
///
/// Implemented in this file for `LocalModelPaths<PathBuf>` and
/// `EmbeddingModelPaths<PathBuf>`.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Paths of the model weight files.
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Path of the model configuration file.
    fn get_config_filename(&self) -> &PathBuf;

    /// Path of the tokenizer file.
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// Path of the template file, if one is available.
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Path of the `gen_conf` file, if one is available.
    ///
    /// NOTE(review): presumably the generation config — confirm against callers.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Path of the preprocessor config file, if one is available.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Path of the processor config file, if one is available.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Path of an explicitly supplied chat template, if one is available.
    ///
    /// `LocalModelPaths` backs this with its `chat_template_json_filename` field.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Locations of the adapter files associated with the model.
    fn get_adapter_paths(&self) -> &AdapterPaths;

    /// Embedding module paths, or `None` when the implementor has no modules.
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
}
96
/// Set of model file locations, generic over the path representation `P`.
///
/// `ModelPaths` is implemented for the `P = PathBuf` instantiation.
#[derive(Clone, Debug)]
pub struct LocalModelPaths<P: Debug> {
    /// Tokenizer file; returned by `get_tokenizer_filename`.
    pub tokenizer_filename: P,
    /// Model configuration file; returned by `get_config_filename`.
    pub config_filename: P,
    /// Template file; `LocalModelPaths::new` always stores `Some(..)` here.
    pub template_filename: Option<P>,
    /// Weight files; returned by `get_weight_filenames`.
    pub filenames: Vec<P>,
    /// Adapter file locations; returned by `get_adapter_paths`.
    pub adapter_paths: AdapterPaths,
    /// Optional file returned by `get_gen_conf_filename`.
    pub gen_conf: Option<P>,
    /// Optional preprocessor config file.
    pub preprocessor_config: Option<P>,
    /// Optional processor config file.
    pub processor_config: Option<P>,
    /// Optional file returned by `get_chat_template_explicit`.
    pub chat_template_json_filename: Option<P>,
}
110
111impl<P: Debug> LocalModelPaths<P> {
112 #[allow(clippy::too_many_arguments)]
113 pub fn new(
114 tokenizer_filename: P,
115 config_filename: P,
116 template_filename: P,
117 filenames: Vec<P>,
118 adapter_paths: AdapterPaths,
119 gen_conf: Option<P>,
120 preprocessor_config: Option<P>,
121 processor_config: Option<P>,
122 chat_template_json_filename: Option<P>,
123 ) -> Self {
124 Self {
125 tokenizer_filename,
126 config_filename,
127 template_filename: Some(template_filename),
128 filenames,
129 adapter_paths,
130 gen_conf,
131 preprocessor_config,
132 processor_config,
133 chat_template_json_filename,
134 }
135 }
136}
137
138impl ModelPaths for LocalModelPaths<PathBuf> {
139 fn get_config_filename(&self) -> &PathBuf {
140 &self.config_filename
141 }
142 fn get_tokenizer_filename(&self) -> &PathBuf {
143 &self.tokenizer_filename
144 }
145 fn get_weight_filenames(&self) -> &[PathBuf] {
146 &self.filenames
147 }
148 fn get_template_filename(&self) -> &Option<PathBuf> {
149 &self.template_filename
150 }
151 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
152 self.gen_conf.as_ref()
153 }
154 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
155 &self.preprocessor_config
156 }
157 fn get_processor_config(&self) -> &Option<PathBuf> {
158 &self.processor_config
159 }
160 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
161 &self.chat_template_json_filename
162 }
163 fn get_adapter_paths(&self) -> &AdapterPaths {
164 &self.adapter_paths
165 }
166 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
167 None
168 }
169}
170
/// Set of file locations for an embedding model, generic over the path representation `P`.
///
/// `ModelPaths` is implemented for the `P = PathBuf` instantiation, where the
/// template, generation-config, preprocessor, processor and chat-template accessors
/// all report nothing.
#[derive(Clone, Debug)]
pub struct EmbeddingModelPaths<P: Debug> {
    /// Tokenizer file; returned by `get_tokenizer_filename`.
    pub tokenizer_filename: P,
    /// Model configuration file; returned by `get_config_filename`.
    pub config_filename: P,
    /// Embedding module paths; returned (as `Some`) by `get_modules`.
    pub modules: Vec<EmbeddingModulePaths>,
    /// Weight files; returned by `get_weight_filenames`.
    pub filenames: Vec<P>,
    /// Adapter file locations; returned by `get_adapter_paths`.
    pub adapter_paths: AdapterPaths,
}
180
181impl<P: Debug> EmbeddingModelPaths<P> {
182 #[allow(clippy::too_many_arguments)]
183 pub fn new(
184 tokenizer_filename: P,
185 config_filename: P,
186 filenames: Vec<P>,
187 adapter_paths: AdapterPaths,
188 modules: Vec<EmbeddingModulePaths>,
189 ) -> Self {
190 Self {
191 tokenizer_filename,
192 config_filename,
193 filenames,
194 adapter_paths,
195 modules,
196 }
197 }
198}
199
200impl ModelPaths for EmbeddingModelPaths<PathBuf> {
201 fn get_config_filename(&self) -> &PathBuf {
202 &self.config_filename
203 }
204 fn get_tokenizer_filename(&self) -> &PathBuf {
205 &self.tokenizer_filename
206 }
207 fn get_weight_filenames(&self) -> &[PathBuf] {
208 &self.filenames
209 }
210 fn get_template_filename(&self) -> &Option<PathBuf> {
211 &None
212 }
213 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
214 None
215 }
216 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
217 &None
218 }
219 fn get_processor_config(&self) -> &Option<PathBuf> {
220 &None
221 }
222 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
223 &None
224 }
225 fn get_adapter_paths(&self) -> &AdapterPaths {
226 &self.adapter_paths
227 }
228 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
229 Some(&self.modules)
230 }
231}
232
/// Where an access token comes from.
///
/// Has a textual form of `kind` or `kind:value`; see the `FromStr` and `Display`
/// implementations for the exact spelling of each variant.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TokenSource {
    /// Written as `literal:<value>`; the value is carried verbatim.
    Literal(String),
    /// Written as `env:<name>`. When parsing a bare `env`, the name defaults to
    /// `HUGGING_FACE_HUB_TOKEN`.
    ///
    /// NOTE(review): presumably the name of an environment variable holding the token —
    /// confirm against the code that consumes this enum.
    EnvVar(String),
    /// Written as `path:<value>`.
    ///
    /// NOTE(review): presumably a file path the token is read from — confirm.
    Path(String),
    /// Written as `cache`.
    CacheToken,
    /// Written as `none`.
    None,
}
242
243impl FromStr for TokenSource {
244 type Err = String;
245
246 fn from_str(s: &str) -> Result<Self, Self::Err> {
247 let parts: Vec<&str> = s.splitn(2, ':').collect();
248 match parts[0] {
249 "literal" => parts
250 .get(1)
251 .map(|&value| TokenSource::Literal(value.to_string()))
252 .ok_or_else(|| "Expected a value for 'literal'".to_string()),
253 "env" => Ok(TokenSource::EnvVar(
254 parts
255 .get(1)
256 .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
257 .to_string(),
258 )),
259 "path" => parts
260 .get(1)
261 .map(|&value| TokenSource::Path(value.to_string()))
262 .ok_or_else(|| "Expected a value for 'path'".to_string()),
263 "cache" => Ok(TokenSource::CacheToken),
264 "none" => Ok(TokenSource::None),
265 _ => Err("Invalid token source format".to_string()),
266 }
267 }
268}
269
270impl fmt::Display for TokenSource {
271 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272 match self {
273 TokenSource::Literal(value) => write!(f, "literal:{value}"),
274 TokenSource::EnvVar(value) => write!(f, "env:{value}"),
275 TokenSource::Path(value) => write!(f, "path:{value}"),
276 TokenSource::CacheToken => write!(f, "cache"),
277 TokenSource::None => write!(f, "none"),
278 }
279 }
280}
281
/// The kind of model being loaded, including any quantization, adapter, or wrapping
/// (speculative / AnyMoE) applied to it.
///
/// The `strum::Display` string on each variant is its human-readable description.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model with neither quantization nor adapter; this is the `Default` variant.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// Quantized model without an adapter.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Non-quantized model with an adapter.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// Quantized model with an adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Pair of nested kinds: a target model and a draft model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// Wrapper around a single nested target kind.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
310
/// Quantization format named by a [`ModelKind`].
///
/// Rendered in kebab-case by `strum::Display`; `strum::EnumIs` supplies the `is_*`
/// predicates.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    // NOTE: variants deliberately carry plain `//` comments only. This enum derives
    // `strum::EnumMessage`, and `PrettyName::pretty_name` returns a variant's
    // `get_documentation()` text when present, so adding a `///` doc comment to a
    // variant would change runtime output.
    Ggml,
    Gguf,
    Gptq,
}
321
/// Adapter type named by a [`ModelKind`].
///
/// Rendered in kebab-case by `strum::Display`; `strum::EnumIs` supplies the `is_*`
/// predicates.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    // NOTE: variants deliberately carry plain `//` comments only. This enum derives
    // `strum::EnumMessage`, and `PrettyName::pretty_name` returns a variant's
    // `get_documentation()` text when present, so adding a `///` doc comment to a
    // variant would change runtime output.
    Lora,
    XLora,
}
330
331pub trait PrettyName: strum::EnumMessage + ToString {
333 fn pretty_name(&self) -> String {
334 match self.get_documentation() {
335 Some(s) => s.to_string(),
336 None => self.to_string(),
339 }
340 }
341}
342
343impl PrettyName for AdapterKind {}
344impl PrettyName for QuantizationKind {}
345
346impl ModelKind {
347 pub fn is_quantized(&self) -> bool {
349 self.quantized_kind().iter().any(|q| q.is_some())
350 }
351
352 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
353 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
354 }
355
356 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
357 use ModelKind::*;
358
359 match self {
360 Normal | Adapter { .. } => vec![None],
361 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
362 Speculative { target, draft } => {
363 let t = *target.clone();
364 let d = *draft.clone();
365
366 [t.quantized_kind(), d.quantized_kind()].concat()
367 }
368 AnyMoe { target } => target.quantized_kind(),
369 }
370 }
371
372 pub fn is_adapted(&self) -> bool {
374 self.adapted_kind().iter().any(|a| a.is_some())
375 }
376
377 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
378 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
379 }
380
381 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
382 use ModelKind::*;
383
384 match self {
385 Normal | GgufQuantized { .. } => vec![None],
386 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
387 Speculative { target, draft } => {
388 let t = *target.clone();
389 let d = *draft.clone();
390
391 [t.adapted_kind(), d.adapted_kind()].concat()
392 }
393 AnyMoe { target } => target.adapted_kind(),
394 }
395 }
396}
397
/// Minimal deserialization target that extracts only the optional
/// `quantization_config` entry from a model configuration.
#[derive(Deserialize)]
pub struct QuantizationConfigShim {
    /// The `quantization_config` entry, or `None` when the configuration has none.
    quantization_config: Option<QuantizedConfig>,
}
402
403impl QuantizationConfigShim {
404 pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
405 let QuantizationConfigShim {
406 quantization_config,
407 } = serde_json::from_str(config)?;
408
409 if let Some(quantization_config) = quantization_config {
410 Ok(quantization_config.pack_factor(dtype))
411 } else {
412 Ok(1)
413 }
414 }
415}
416
/// Size and shape queries a model loader answers so that its layers can be assigned
/// to devices.
///
/// Every `config` parameter is the model configuration as a string.
/// NOTE(review): presumably the JSON contents of the model's config file — confirm
/// against implementors.
pub trait DeviceMappedModelLoader {
    /// NOTE(review): presumably the maximum activation size, in elements, of the parts
    /// of the model that are *not* device-mapped — confirm against implementors.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// NOTE(review): presumably the maximum activation size, in elements, of the
    /// device-mapped layers — confirm against implementors.
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Size in bytes of the non-mapped part of the model for the given `dtype`.
    ///
    /// NOTE(review): `weight_pack_factor` is presumably the value produced by
    /// `QuantizationConfigShim::get_quant_config_pack_factor` — confirm against callers.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// Sizes in bytes of the model's layers for the given `dtype`.
    ///
    /// NOTE(review): presumably one entry per layer, in layer order — confirm.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    /// Sub-models that are not device-mapped; the provided implementation reports `None`.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of layers described by `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Builds a boxed [`ModelConfigLike`] from `config`.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Computes the device mapping for the model.
    ///
    /// The provided implementation forwards `self` and every argument, unchanged and
    /// in order, to `auto_device_map::get_device_layers`. It is only available when
    /// `Self: Sized`.
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}
483
/// Builds a ready-to-use [`Pipeline`] for a model.
///
/// Both loading methods return the pipeline as a trait object behind
/// `Arc<tokio::sync::Mutex<..>>`.
pub trait Loader: Send + Sync {
    /// Loads the model identified by this loader, resolving its files from an optional
    /// `revision` and authenticating with `token_source`.
    ///
    /// NOTE(review): "hf" presumably refers to the Hugging Face Hub — confirm against
    /// implementors.
    ///
    /// # Errors
    ///
    /// Returns an error if loading fails; the conditions are implementor-defined.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Loads the model from an already resolved set of file locations (`paths`).
    ///
    /// # Errors
    ///
    /// Returns an error if loading fails; the conditions are implementor-defined.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier of the model this loader loads.
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}