Skip to main content

hanzo_engine/pipeline/loaders/
mod.rs

1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod embedding_loaders;
4mod multimodal_loaders;
5mod normal_loaders;
6pub use auto_device_map::AutoDeviceMapParams;
7use auto_device_map::NonMappedSubModel;
8
9use std::{
10    fmt::{self, Debug},
11    path::PathBuf,
12    str::FromStr,
13    sync::Arc,
14};
15
16use anyhow::Result;
17use as_any::AsAny;
18use hanzo_ml::{DType, Device};
19use hanzo_quant::{IsqType, QuantizedConfig};
20use serde::Deserialize;
21use tokio::sync::Mutex;
22
23pub use normal_loaders::{
24    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
25    GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
26    MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
27    NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
28    Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
29};
30
31pub use multimodal_loaders::{
32    AutoMultimodalLoader, Gemma3Loader, Gemma3nLoader, Gemma4Loader, Idefics2Loader,
33    Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader, Mistral3Loader,
34    MultimodalLoaderType, MultimodalModel, MultimodalModelLoader, Phi3VLoader, Phi4MMLoader,
35    Qwen2VLLoader, Qwen2_5VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, Qwen3_5Loader,
36    Qwen3_5MoeLoader, VLlama4Loader, VLlamaLoader, VoxtralLoader,
37};
38
39pub use embedding_loaders::{
40    AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
41    EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
42    Qwen3EmbeddingLoader,
43};
44
45pub use diffusion_loaders::{
46    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
47    DiffusionModelPathsInner, FluxLoader,
48};
49
50use crate::{
51    matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
52    DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
53};
54
55use super::{paths::AdapterPaths, Pipeline};
56
57/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
58/// example `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
59pub trait ModelPaths: AsAny + Debug + Send + Sync {
60    /// Model weights files (multiple files supported).
61    fn get_weight_filenames(&self) -> &[PathBuf];
62
63    /// Retrieve the [`PretrainedConfig`] file.
64    ///
65    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
66    fn get_config_filename(&self) -> &PathBuf;
67
68    /// A serialised [`tokenizers.Tokenizer`] HuggingFace object.
69    ///
70    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
71    fn get_tokenizer_filename(&self) -> &PathBuf;
72
73    /// File where the content is expected to deserialize to [`ChatTemplate`].
74    ///
75    /// [`ChatTemplate`]: crate::ChatTemplate
76    fn get_template_filename(&self) -> &Option<PathBuf>;
77
78    /// Filepath for general model configuration.
79    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;
80
81    /// Get the preprocessor config (for the multimodal models). This is used to pre process images.
82    fn get_preprocessor_config(&self) -> &Option<PathBuf>;
83
84    /// Get the processor config (for the multimodal models). This is primarily used for the chat template.
85    fn get_processor_config(&self) -> &Option<PathBuf>;
86
87    /// Get the explicit chat template.
88    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;
89
90    /// Get adapter paths.
91    fn get_adapter_paths(&self) -> &AdapterPaths;
92
93    /// Get embedding model `modules.json` compatible with sentence-transformers
94    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
95}
96
97#[derive(Clone, Debug)]
98/// All local paths and metadata necessary to load a model.
99pub struct LocalModelPaths<P: Debug> {
100    pub tokenizer_filename: P,
101    pub config_filename: P,
102    pub template_filename: Option<P>,
103    pub filenames: Vec<P>,
104    pub adapter_paths: AdapterPaths,
105    pub gen_conf: Option<P>,
106    pub preprocessor_config: Option<P>,
107    pub processor_config: Option<P>,
108    pub chat_template_json_filename: Option<P>,
109}
110
111impl<P: Debug> LocalModelPaths<P> {
112    #[allow(clippy::too_many_arguments)]
113    pub fn new(
114        tokenizer_filename: P,
115        config_filename: P,
116        template_filename: P,
117        filenames: Vec<P>,
118        adapter_paths: AdapterPaths,
119        gen_conf: Option<P>,
120        preprocessor_config: Option<P>,
121        processor_config: Option<P>,
122        chat_template_json_filename: Option<P>,
123    ) -> Self {
124        Self {
125            tokenizer_filename,
126            config_filename,
127            template_filename: Some(template_filename),
128            filenames,
129            adapter_paths,
130            gen_conf,
131            preprocessor_config,
132            processor_config,
133            chat_template_json_filename,
134        }
135    }
136}
137
138impl ModelPaths for LocalModelPaths<PathBuf> {
139    fn get_config_filename(&self) -> &PathBuf {
140        &self.config_filename
141    }
142    fn get_tokenizer_filename(&self) -> &PathBuf {
143        &self.tokenizer_filename
144    }
145    fn get_weight_filenames(&self) -> &[PathBuf] {
146        &self.filenames
147    }
148    fn get_template_filename(&self) -> &Option<PathBuf> {
149        &self.template_filename
150    }
151    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
152        self.gen_conf.as_ref()
153    }
154    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
155        &self.preprocessor_config
156    }
157    fn get_processor_config(&self) -> &Option<PathBuf> {
158        &self.processor_config
159    }
160    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
161        &self.chat_template_json_filename
162    }
163    fn get_adapter_paths(&self) -> &AdapterPaths {
164        &self.adapter_paths
165    }
166    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
167        None
168    }
169}
170
171#[derive(Clone, Debug)]
172/// All local paths and metadata necessary to load an embedding model.
173pub struct EmbeddingModelPaths<P: Debug> {
174    pub tokenizer_filename: P,
175    pub config_filename: P,
176    pub modules: Vec<EmbeddingModulePaths>,
177    pub filenames: Vec<P>,
178    pub adapter_paths: AdapterPaths,
179}
180
181impl<P: Debug> EmbeddingModelPaths<P> {
182    #[allow(clippy::too_many_arguments)]
183    pub fn new(
184        tokenizer_filename: P,
185        config_filename: P,
186        filenames: Vec<P>,
187        adapter_paths: AdapterPaths,
188        modules: Vec<EmbeddingModulePaths>,
189    ) -> Self {
190        Self {
191            tokenizer_filename,
192            config_filename,
193            filenames,
194            adapter_paths,
195            modules,
196        }
197    }
198}
199
200impl ModelPaths for EmbeddingModelPaths<PathBuf> {
201    fn get_config_filename(&self) -> &PathBuf {
202        &self.config_filename
203    }
204    fn get_tokenizer_filename(&self) -> &PathBuf {
205        &self.tokenizer_filename
206    }
207    fn get_weight_filenames(&self) -> &[PathBuf] {
208        &self.filenames
209    }
210    fn get_template_filename(&self) -> &Option<PathBuf> {
211        &None
212    }
213    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
214        None
215    }
216    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
217        &None
218    }
219    fn get_processor_config(&self) -> &Option<PathBuf> {
220        &None
221    }
222    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
223        &None
224    }
225    fn get_adapter_paths(&self) -> &AdapterPaths {
226        &self.adapter_paths
227    }
228    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
229        Some(&self.modules)
230    }
231}
232
233#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
234/// The source of the HF token.
235pub enum TokenSource {
236    Literal(String),
237    EnvVar(String),
238    Path(String),
239    CacheToken,
240    None,
241}
242
243impl FromStr for TokenSource {
244    type Err = String;
245
246    fn from_str(s: &str) -> Result<Self, Self::Err> {
247        let parts: Vec<&str> = s.splitn(2, ':').collect();
248        match parts[0] {
249            "literal" => parts
250                .get(1)
251                .map(|&value| TokenSource::Literal(value.to_string()))
252                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
253            "env" => Ok(TokenSource::EnvVar(
254                parts
255                    .get(1)
256                    .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
257                    .to_string(),
258            )),
259            "path" => parts
260                .get(1)
261                .map(|&value| TokenSource::Path(value.to_string()))
262                .ok_or_else(|| "Expected a value for 'path'".to_string()),
263            "cache" => Ok(TokenSource::CacheToken),
264            "none" => Ok(TokenSource::None),
265            _ => Err("Invalid token source format".to_string()),
266        }
267    }
268}
269
270impl fmt::Display for TokenSource {
271    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272        match self {
273            TokenSource::Literal(value) => write!(f, "literal:{value}"),
274            TokenSource::EnvVar(value) => write!(f, "env:{value}"),
275            TokenSource::Path(value) => write!(f, "path:{value}"),
276            TokenSource::CacheToken => write!(f, "cache"),
277            TokenSource::None => write!(f, "none"),
278        }
279    }
280}
281
282/// The kind of model to build.
283#[derive(Clone, Default, derive_more::From, strum::Display)]
284pub enum ModelKind {
285    #[default]
286    #[strum(to_string = "normal (no adapters)")]
287    Normal,
288
289    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
290    GgufQuantized { quant: QuantizationKind },
291
292    #[strum(to_string = "{adapter}")]
293    Adapter { adapter: AdapterKind },
294
295    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
296    GgufAdapter {
297        adapter: AdapterKind,
298        quant: QuantizationKind,
299    },
300
301    #[strum(to_string = "anymoe: target: `{target}`")]
302    AnyMoe { target: Box<ModelKind> },
303}
304
305#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
306#[strum(serialize_all = "kebab-case")]
307pub enum QuantizationKind {
308    /// GGML
309    Ggml,
310    /// GGUF
311    Gguf,
312    /// GPTQ
313    Gptq,
314}
315
316#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
317#[strum(serialize_all = "kebab-case")]
318pub enum AdapterKind {
319    /// LoRA
320    Lora,
321    /// X-LoRA
322    XLora,
323}
324
325// For the proper name as formatted via doc comment for a variant
326pub trait PrettyName: strum::EnumMessage + ToString {
327    fn pretty_name(&self) -> String {
328        match self.get_documentation() {
329            Some(s) => s.to_string(),
330            // Instead of panic via expect(),
331            // fallback to default kebab-case:
332            None => self.to_string(),
333        }
334    }
335}
336
337impl PrettyName for AdapterKind {}
338impl PrettyName for QuantizationKind {}
339
340impl ModelKind {
341    // Quantized helpers:
342    pub fn is_quantized(&self) -> bool {
343        self.quantized_kind().iter().any(|q| q.is_some())
344    }
345
346    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
347        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
348    }
349
350    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
351        use ModelKind::*;
352
353        match self {
354            Normal | Adapter { .. } => vec![None],
355            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
356            AnyMoe { target } => target.quantized_kind(),
357        }
358    }
359
360    // Adapter helpers:
361    pub fn is_adapted(&self) -> bool {
362        self.adapted_kind().iter().any(|a| a.is_some())
363    }
364
365    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
366        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
367    }
368
369    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
370        use ModelKind::*;
371
372        match self {
373            Normal | GgufQuantized { .. } => vec![None],
374            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
375            AnyMoe { target } => target.adapted_kind(),
376        }
377    }
378}
379
380#[derive(Deserialize)]
381pub struct QuantizationConfigShim {
382    quantization_config: Option<QuantizedConfig>,
383}
384
385impl QuantizationConfigShim {
386    pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
387        let QuantizationConfigShim {
388            quantization_config,
389        } = serde_json::from_str(config)?;
390
391        if let Some(quantization_config) = quantization_config {
392            Ok(quantization_config.pack_factor(dtype))
393        } else {
394            Ok(1)
395        }
396    }
397}
398
399pub trait DeviceMappedModelLoader {
400    /// Maximum activation size of non-mapped parts of this model.
401    /// Useful for the multimodal models which may prefer to keep the vison components on the GPU.
402    fn non_mapped_max_act_size_elems(
403        &self,
404        config: &str,
405        params: &AutoDeviceMapParams,
406    ) -> Result<usize>;
407    /// Maximum activation size of mapped parts of the model
408    fn mapped_max_act_size_elems(
409        &self,
410        config: &str,
411        params: &AutoDeviceMapParams,
412    ) -> Result<usize>;
413    /// weight_pack_factor only applies to quantized weights.
414    fn non_mapped_size_in_bytes(
415        &self,
416        config: &str,
417        dtype: DType,
418        weight_pack_factor: usize,
419        matformer_config: Option<&MatformerSliceConfig>,
420    ) -> Result<usize>;
421    /// weight_pack_factor only applies to quantized weights.
422    fn layer_sizes_in_bytes(
423        &self,
424        config: &str,
425        dtype: DType,
426        weight_pack_factor: usize,
427        matformer_config: Option<&MatformerSliceConfig>,
428    ) -> Result<Vec<usize>>;
429    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
430        None
431    }
432    fn num_layers(&self, config: &str) -> Result<usize>;
433    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;
434
435    #[allow(clippy::too_many_arguments)]
436    fn get_device_layers(
437        &self,
438        config: &str,
439        num_layers: usize,
440        layer_sizes_in_bytes: Vec<usize>,
441        non_mapped_size_in_bytes: usize,
442        total_model_size_in_bytes: usize,
443        devices: &[Device],
444        dtype: DType,
445        params: &AutoDeviceMapParams,
446        paged_attn_config: Option<&PagedAttentionConfig>,
447    ) -> Result<DeviceMapMetadata>
448    where
449        Self: Sized,
450    {
451        auto_device_map::get_device_layers(
452            self,
453            config,
454            num_layers,
455            layer_sizes_in_bytes,
456            non_mapped_size_in_bytes,
457            total_model_size_in_bytes,
458            devices,
459            dtype,
460            params,
461            paged_attn_config,
462        )
463    }
464}
465
466/// The `Loader` trait abstracts the loading process. The primary entrypoint is the
467/// `load_model` method.
468///
469/// # Example
470/// ```no_run
471/// use hanzo_engine::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
472/// use hanzo_ml::Device;
473///
474/// let loader: Box<dyn Loader> = todo!();
475/// let pipeline = loader.load_model_from_hf(
476///     None,
477///     TokenSource::CacheToken,
478///     &ModelDType::Auto,
479///     &Device::cuda_if_available(0).unwrap(),
480///     false,
481///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
482///     None,
483///     None,
484/// ).unwrap();
485/// ```
486pub trait Loader: Send + Sync {
487    /// If `revision` is None, then it defaults to `main`.
488    /// If `dtype` is None, then it defaults to the model default (usually BF16).
489    /// If model is not found on HF, will attempt to resolve locally.
490    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
491    fn load_model_from_hf(
492        &self,
493        revision: Option<String>,
494        token_source: TokenSource,
495        dtype: &dyn TryIntoDType,
496        device: &Device,
497        silent: bool,
498        mapper: DeviceMapSetting,
499        in_situ_quant: Option<IsqType>,
500        paged_attn_config: Option<PagedAttentionConfig>,
501    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
502
503    /// Load a model from the specified paths.
504    /// Also initializes `DEBUG`.
505    #[allow(
506        clippy::type_complexity,
507        clippy::too_many_arguments,
508        clippy::borrowed_box
509    )]
510    fn load_model_from_path(
511        &self,
512        paths: &Box<dyn ModelPaths>,
513        dtype: &dyn TryIntoDType,
514        device: &Device,
515        silent: bool,
516        mapper: DeviceMapSetting,
517        in_situ_quant: Option<IsqType>,
518        paged_attn_config: Option<PagedAttentionConfig>,
519    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
520
521    fn get_id(&self) -> String;
522    fn get_kind(&self) -> ModelKind;
523}