Skip to main content

hanzo_engine/pipeline/
multimodal.rs

1use super::isq::{ImatrixDataSource, UqffFullSer, WeightLoadingMode, WeightLoadingState};
2use super::{
3    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoMultimodalLoader,
4    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
5    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
6    ModelKind, ModelPaths, MultimodalModel, MultimodalModelLoader, MultimodalPromptPrefixer,
7    Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, Qwen3VLLoader, Qwen3VLMoELoader,
8    Qwen3_5Loader, Qwen3_5MoeLoader, TokenSource, VLlama4Loader, VLlamaLoader,
9};
10use super::{
11    Gemma3nLoader, Gemma4Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
12    Mistral3Loader, MultimodalLoaderType, Phi3VLoader, Qwen2_5VLLoader, VoxtralLoader,
13};
14use crate::attention::ATTENTION_CHUNK_SIZE;
15use crate::device_map::{self, DeviceMapper};
16use crate::distributed::{self, use_ring, WorkerTransferData};
17use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
18use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
19use crate::pipeline::chat_template::{
20    calculate_eos_tokens, BeginEndUnkPadTok, ChatTemplateValue, GenerationConfig,
21};
22use crate::pipeline::llg::build_llg_factory;
23use crate::pipeline::loaders::auto_device_map;
24use crate::pipeline::loaders::QuantizationConfigShim;
25use crate::pipeline::sampling::sample_and_add_toks;
26use crate::pipeline::text_models_inputs_processor::{make_prompt_chunk, InputMetadata};
27use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
28use crate::prefix_cacher::PrefixCacheManagerV2;
29use crate::sequence::Sequence;
30use crate::utils::tokenizer::get_tokenizer;
31use crate::utils::varbuilder_utils::DeviceForLoadTensor;
32use crate::utils::{
33    progress::{new_multi_progress, ProgressScopeGuard},
34    tokens::get_token,
35    varbuilder_utils::from_mmaped_safetensors,
36};
37use crate::vision_models::preprocessor_config::PreProcessorConfig;
38use crate::vision_models::processor_config::ProcessorConfig;
39use crate::vision_models::ModelInputs;
40use crate::{
41    api_dir_list, api_get_file, get_paths, get_uqff_paths, multimodal_normal_model_loader,
42    multimodal_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
43    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
44};
45use anyhow::Result;
46use either::Either;
47use hanzo_ml::{Device, Tensor, Var};
48use hanzo_quant::log::once_log_info;
49use hanzo_quant::{
50    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
51};
52use hf_hub::Cache;
53use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
54use rand_isaac::Isaac64Rng;
55use regex_automata::meta::Regex;
56use std::any::Any;
57use std::borrow::Cow;
58use std::path::{Path, PathBuf};
59use std::str::FromStr;
60use std::sync::{Arc, RwLock};
61use std::time::Instant;
62use std::{env, fs};
63use tokenizers::Tokenizer;
64use tokio::sync::Mutex;
65use tracing::{debug, info, trace, warn};
66
/// A loaded multimodal model bundled with the tokenizer, chat template, processors and
/// device-mapping state that accompany it.
///
/// NOTE(review): the code that constructs and drives this struct lies outside the visible part
/// of the file; the field notes below are derived from field names and types only — confirm
/// against the constructor before relying on them.
pub struct MultimodalPipeline {
    /// The model, stored behind the `MultimodalModel` trait object.
    model: Box<dyn MultimodalModel + Send + Sync>,
    /// Shared tokenizer.
    tokenizer: Arc<Tokenizer>,
    /// Shared chat template.
    chat_template: Arc<ChatTemplate>,
    /// Identifier of the model (presumably the id the loader was built with — TODO confirm).
    model_id: String,
    /// Shared general metadata.
    metadata: Arc<GeneralMetadata>,
    /// Input processor, stored behind the `Processor` trait object.
    processor: Arc<dyn Processor + Send + Sync>,
    /// Preprocessor configuration.
    preprocessor_config: Arc<PreProcessorConfig>,
    /// Optional topology.
    topology: Option<Topology>,
    /// Presumably suppresses progress output, like the `silent` loader argument — TODO confirm.
    silent: bool,
    /// Prompt prefixer, stored behind the `MultimodalPromptPrefixer` trait object.
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    /// Device mapper, stored behind the `DeviceMapper` trait object.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// ISQ organization.
    organization: IsqOrganization,

    // The remaining fields are kept for full UQFF serialization.
    /// Optional path of the chat template file.
    template_filename: Option<PathBuf>,
    /// Optional path of the generation config file.
    generation_config: Option<PathBuf>,
    /// Optional model generation defaults.
    generation_defaults: Option<crate::ModelGenerationDefaults>,
    /// Model config as a string (presumably the raw config file contents — TODO confirm).
    config: String,
    /// Optional path of the processor config file.
    processor_filename: Option<PathBuf>,
    /// Optional path of the preprocessor config file.
    preprocessor_filename: Option<PathBuf>,
    /// Optional path of an imatrix file.
    imatrix: Option<PathBuf>,
}
90
/// A loader for a multimodal (non-quantized) model.
///
/// Instances are produced by `MultimodalLoaderBuilder::build`. The `RwLock` fields are filled
/// in by `load_model_from_hf`, which only has `&self`.
pub struct MultimodalLoader {
    /// Architecture-specific loader that the generic loading code delegates to
    /// (layer sizes, processor creation, ISQ predicates, ...).
    inner: Box<dyn MultimodalModelLoader>,
    /// Model identifier supplied to the builder.
    model_id: String,
    /// Multimodal-specific loading options.
    config: MultimodalSpecificConfig,
    /// `ModelKind::Normal`, or a LoRA adapter kind when the builder's `with_lora` was used.
    kind: ModelKind,
    /// Optional chat template override, forwarded to `get_chat_template`.
    chat_template: Option<String>,
    /// Optional tokenizer JSON override.
    /// NOTE(review): not read in the visible code — presumably consumed by path resolution.
    tokenizer_json: Option<String>,
    /// Always `None` when created through `MultimodalLoaderBuilder::build`.
    xlora_model_id: Option<String>,
    /// Always `None` when created through `MultimodalLoaderBuilder::build`.
    xlora_order: Option<Ordering>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// Resolved UQFF file paths; set by `load_model_from_hf` when `config.from_uqff` is given.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Optional explicit JINJA chat template, forwarded to `get_chat_template`.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory used to build the `hf_hub` cache.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids supplied through the builder's `with_lora`, if any.
    lora_adapter_ids: Option<Vec<String>>,
}
108
#[derive(Default)]
/// A builder for a loader for a multimodal (non-quantized) model.
///
/// Create it with `new`, optionally refine it with `hf_cache_path` / `with_lora`, then call
/// `build` to obtain the boxed loader.
pub struct MultimodalLoaderBuilder {
    /// Model identifier; `build` requires it to be `Some`.
    model_id: Option<String>,
    /// Multimodal-specific loading options handed to the loader unchanged.
    config: MultimodalSpecificConfig,
    /// `ModelKind::Normal` after `new`; switched to a LoRA adapter kind by `with_lora`.
    kind: ModelKind,
    /// Optional chat template override.
    chat_template: Option<String>,
    /// Optional tokenizer JSON override.
    tokenizer_json: Option<String>,
    /// Optional explicit JINJA chat template.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory; `None` until `hf_cache_path` is called.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids; `None` until `with_lora` is called.
    lora_adapter_ids: Option<Vec<String>>,
}
121
#[derive(Clone, Default)]
/// Config specific to loading a multimodal model.
pub struct MultimodalSpecificConfig {
    /// Optional topology. It is handed to the device mapper, and its pattern overrides are
    /// turned into immediate-ISQ overrides during loading.
    pub topology: Option<Topology>,
    /// When set, loading reports that UQFF will be written
    /// (presumably the output path of the serialized UQFF — TODO confirm).
    pub write_uqff: Option<PathBuf>,
    /// UQFF files to load from; resolved through `get_uqff_paths!` in `load_model_from_hf`.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Forwarded to the architecture loader's `get_processor`.
    pub max_edge: Option<u32>,
    /// Optional imatrix file. Must not be combined with `calibration_file`; loading fails if
    /// both are set.
    pub imatrix: Option<PathBuf>,
    /// Optional text file that is tokenized at load time to collect an imatrix. Must not be
    /// combined with `imatrix`; loading fails if both are set.
    pub calibration_file: Option<PathBuf>,
    /// NOTE(review): not read in the visible code, which builds the Hugging Face cache from the
    /// loader's own `hf_cache_path` — confirm whether this field is consumed elsewhere.
    pub hf_cache_path: Option<PathBuf>,
    /// Path of a Matformer config file, read with `MatformerConfig::from_file` when present.
    pub matformer_config_path: Option<PathBuf>,
    /// Name of the Matformer slice to use. If a config path is given without a slice name, a
    /// warning is logged and no slicing config is passed on.
    pub matformer_slice_name: Option<String>,
    /// Selects between default ISQ handling and the MoE-experts-only variant.
    pub organization: IsqOrganization,
}
136
137impl MultimodalLoaderBuilder {
138    pub fn new(
139        config: MultimodalSpecificConfig,
140        chat_template: Option<String>,
141        tokenizer_json: Option<String>,
142        model_id: Option<String>,
143        jinja_explicit: Option<String>,
144    ) -> Self {
145        Self {
146            config,
147            chat_template,
148            tokenizer_json,
149            model_id,
150            jinja_explicit,
151            kind: ModelKind::Normal,
152            hf_cache_path: None,
153            ..Default::default()
154        }
155    }
156
157    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
158        self.hf_cache_path = Some(hf_cache_path);
159        self
160    }
161
162    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
163        self.kind = ModelKind::Adapter {
164            adapter: AdapterKind::Lora,
165        };
166        self.lora_adapter_ids = Some(lora_adapter_ids);
167        self
168    }
169
170    pub fn build(self, loader: Option<MultimodalLoaderType>) -> Box<dyn Loader> {
171        let loader: Box<dyn MultimodalModelLoader> = match loader {
172            Some(MultimodalLoaderType::Phi3V) => Box::new(Phi3VLoader),
173            Some(MultimodalLoaderType::Idefics2) => Box::new(Idefics2Loader),
174            Some(MultimodalLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
175            Some(MultimodalLoaderType::LLaVA) => Box::new(LLaVALoader),
176            Some(MultimodalLoaderType::VLlama) => Box::new(VLlamaLoader),
177            Some(MultimodalLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
178            Some(MultimodalLoaderType::Idefics3) => Box::new(Idefics3Loader),
179            Some(MultimodalLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
180            Some(MultimodalLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
181            Some(MultimodalLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
182            Some(MultimodalLoaderType::Gemma3) => Box::new(Gemma3Loader),
183            Some(MultimodalLoaderType::Mistral3) => Box::new(Mistral3Loader),
184            Some(MultimodalLoaderType::Llama4) => Box::new(VLlama4Loader),
185            Some(MultimodalLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
186            Some(MultimodalLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
187            Some(MultimodalLoaderType::Qwen3VLMoE) => Box::new(Qwen3VLMoELoader),
188            Some(MultimodalLoaderType::Qwen3_5) => Box::new(Qwen3_5Loader),
189            Some(MultimodalLoaderType::Qwen3_5Moe) => Box::new(Qwen3_5MoeLoader),
190            Some(MultimodalLoaderType::Voxtral) => Box::new(VoxtralLoader),
191            Some(MultimodalLoaderType::Gemma4) => Box::new(Gemma4Loader),
192            None => Box::new(AutoMultimodalLoader),
193        };
194        Box::new(MultimodalLoader {
195            inner: loader,
196            model_id: self.model_id.unwrap(),
197            config: self.config,
198            kind: self.kind,
199            chat_template: self.chat_template,
200            tokenizer_json: self.tokenizer_json,
201            xlora_model_id: None,
202            xlora_order: None,
203            jinja_explicit: self.jinja_explicit,
204            token_source: RwLock::new(None),
205            revision: RwLock::new(None),
206            from_uqff: RwLock::new(None),
207            hf_cache_path: self.hf_cache_path,
208            lora_adapter_ids: self.lora_adapter_ids,
209        })
210    }
211}
212
213impl Loader for MultimodalLoader {
214    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
215    fn load_model_from_hf(
216        &self,
217        revision: Option<String>,
218        token_source: TokenSource,
219        dtype: &dyn TryIntoDType,
220        device: &Device,
221        silent: bool,
222        mapper: DeviceMapSetting,
223        in_situ_quant: Option<IsqType>,
224        paged_attn_config: Option<PagedAttentionConfig>,
225    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
226        let _progress_guard = ProgressScopeGuard::new(silent);
227        let cache = self
228            .hf_cache_path
229            .clone()
230            .map(Cache::new)
231            .unwrap_or_default();
232        GLOBAL_HF_CACHE.get_or_init(|| cache);
233
234        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
235            LocalModelPaths,
236            &token_source,
237            revision.clone(),
238            self,
239            None,
240            None,
241            silent,
242            self.config.from_uqff.is_some()
243        );
244        *self
245            .token_source
246            .write()
247            .expect("Failed to write to token source") = Some(token_source);
248        *self.revision.write().expect("Failed to write to revision") = revision.clone();
249        if let Some(from_uqff) = self.config.from_uqff.clone() {
250            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
251        }
252        self.load_model_from_path(
253            &paths?,
254            dtype,
255            device,
256            silent,
257            mapper,
258            in_situ_quant,
259            paged_attn_config,
260        )
261    }
262
263    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
264    fn load_model_from_path(
265        &self,
266        paths: &Box<dyn ModelPaths>,
267        dtype: &dyn TryIntoDType,
268        device: &Device,
269        silent: bool,
270        mut mapper: DeviceMapSetting,
271        in_situ_quant: Option<IsqType>,
272        mut paged_attn_config: Option<PagedAttentionConfig>,
273    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
274        let _progress_guard = ProgressScopeGuard::new(silent);
275        let config = std::fs::read_to_string(paths.get_config_filename())?;
276
277        if !self.inner.supports_paged_attention(&config) {
278            paged_attn_config = None;
279        }
280
281        debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
282
283        let use_nccl = hanzo_quant::distributed::use_nccl();
284
285        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
286            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
287            let WorkerTransferData::Init { id: _, worker_rank } = payload;
288            // Use new_cuda instead of new_cuda_with_stream for NCCL compatibility
289            // NCCL manages its own streams, so explicit stream creation can cause conflicts
290            vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
291        } else if use_nccl || use_ring() {
292            vec![hanzo_ml::Device::new_cuda(0)?]
293        } else {
294            device_map::get_all_similar_devices(device)?
295        };
296        #[cfg(feature = "cuda")]
297        for device in &available_devices {
298            if let Device::Cuda(dev) = device {
299                unsafe { dev.disable_event_tracking() };
300            }
301        }
302        let device = if use_nccl || use_ring() {
303            available_devices[0].clone()
304        } else {
305            device.clone()
306        };
307
308        // Load matformer slicing config if provided
309        let matformer_slicing_config = if let Some(matformer_path) =
310            &self.config.matformer_config_path
311        {
312            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
313            info!("Loading Matformer config from {:?}", matformer_path);
314            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
315
316            if let Some(slice_name) = &self.config.matformer_slice_name {
317                info!("Using Matformer slice: {}", slice_name);
318                Some(MatformerSliceConfig::new(slice_name.clone(), config))
319            } else {
320                // If no slice name is provided but config exists, we'll need to handle this
321                // For now, return None and let the model handle the default slice selection
322                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
323                None
324            }
325        } else {
326            None
327        };
328
329        // If auto, convert to Map if not using nccl
330        let mut max_kv_tokens: Option<usize> = None;
331        if use_nccl || use_ring() {
332            mapper = DeviceMapSetting::DummyNccl {
333                nm_device: available_devices[0].clone(),
334            };
335        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
336            // We can promote to multimodal params if we get text params
337            params = params.maybe_promote_to_multimodal();
338            max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
339
340            // Initial dtype
341            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
342
343            // ISQ or UQFF: quantized path
344            // Match logic below where UQFF has priority
345            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
346                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
347                    let weight_pack_factor = {
348                        let ser_artifacts =
349                            unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(serialized)? };
350                        let mut total_pack_factors = 0;
351                        let total_tensors = ser_artifacts.tensors().len();
352                        for (_, artifact) in ser_artifacts.tensors() {
353                            let artifact = artifact.data();
354                            // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
355                            let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
356                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
357                            {
358                                QuantizedSerdeType::Hqq => {
359                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
360                                        .pack_factor(dtype)
361                                }
362                                QuantizedSerdeType::Gguf => {
363                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
364                                        .pack_factor(dtype)
365                                }
366                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
367                                QuantizedSerdeType::Unquant => 1,
368                                QuantizedSerdeType::Afq => {
369                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
370                                        .pack_factor(dtype)
371                                }
372                                QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
373                                QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
374                            };
375                            total_pack_factors += pack_factor;
376                        }
377
378                        total_pack_factors / total_tensors
379                    };
380
381                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
382                        &config,
383                        dtype,
384                        weight_pack_factor,
385                        matformer_slicing_config.as_ref(),
386                    )?;
387                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
388                        &config,
389                        dtype,
390                        weight_pack_factor,
391                        matformer_slicing_config.as_ref(),
392                    )?;
393                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
394                    (
395                        layer_sizes_in_bytes,
396                        non_mapped_size_in_bytes,
397                        layer_sizes_sum + non_mapped_size_in_bytes,
398                    )
399                } else if let Some(isq) = in_situ_quant {
400                    let weight_pack_factor = isq.pack_factor(dtype);
401                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
402                        &config,
403                        dtype,
404                        weight_pack_factor,
405                        matformer_slicing_config.as_ref(),
406                    )?;
407                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
408                        &config,
409                        dtype,
410                        weight_pack_factor,
411                        matformer_slicing_config.as_ref(),
412                    )?;
413                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
414                    (
415                        layer_sizes_in_bytes,
416                        non_mapped_size_in_bytes,
417                        layer_sizes_sum + non_mapped_size_in_bytes,
418                    )
419                } else {
420                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
421                    let weight_pack_factor =
422                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
423                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
424                        &config,
425                        dtype,
426                        weight_pack_factor,
427                        matformer_slicing_config.as_ref(),
428                    )?;
429                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
430                        &config,
431                        dtype,
432                        weight_pack_factor,
433                        matformer_slicing_config.as_ref(),
434                    )?;
435                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
436                    (
437                        layer_sizes_in_bytes,
438                        non_mapped_size_in_bytes,
439                        layer_sizes_sum + non_mapped_size_in_bytes,
440                    )
441                };
442
443            let new = auto_device_map::get_device_layers(
444                &*self.inner,
445                &config,
446                self.inner.num_layers(&config)?,
447                layer_sizes_in_bytes,
448                non_mapped_size_in_bytes,
449                total_model_size_in_bytes,
450                &available_devices,
451                dtype,
452                &params,
453                paged_attn_config.as_ref(),
454            )?;
455            mapper = DeviceMapSetting::Map(new);
456        }
457
458        let pipeline_mapper = mapper.into_mapper(
459            self.inner.num_layers(&config)?,
460            &device,
461            self.config.topology.as_ref(),
462            &available_devices,
463        )?;
464        let mapper = mapper.into_mapper(
465            self.inner.num_layers(&config)?,
466            &device,
467            self.config.topology.as_ref(),
468            &available_devices,
469        )?;
470        let mut layer_devices = Vec::new();
471        for layer in 0..self.inner.num_layers(&config)? {
472            let device = mapper.device_for(layer, false).cloned();
473            layer_devices.push(device);
474        }
475        let dtype = mapper.get_min_dtype(dtype)?;
476
477        // TODO: PagedAttention is not supported with CPU for now.
478        // This check is not really necessary because `get_device_layers` should prevent it.
479        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
480        if mapping_uses_cpu && paged_attn_config.is_some() {
481            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
482            paged_attn_config = None;
483        }
484
485        trace!("Model config: {:?}", self.inner.get_config_repr(&config)?);
486        if crate::using_flash_attn() {
487            once_log_info("FlashAttention is enabled.");
488        }
489
490        let topology_overrides = self
491            .config
492            .topology
493            .as_ref()
494            .map(|topology| {
495                topology
496                    .pattern_overrides()
497                    .into_iter()
498                    .map(|(regex, layer)| ImmediateIsqOverride {
499                        predicate: regex,
500                        ty: layer.isq,
501                        device: layer.device.clone(),
502                    })
503                    .collect::<Vec<_>>()
504            })
505            .unwrap_or_default();
506        let has_override_isq = topology_overrides
507            .iter()
508            .any(|override_entry| override_entry.ty.is_some());
509        let topology_requires_post_quant = self
510            .config
511            .topology
512            .as_ref()
513            .is_some_and(|topology| topology.requires_post_quantization());
514
515        let allow_immediate_cli = self.config.imatrix.is_none()
516            && self.config.calibration_file.is_none()
517            && in_situ_quant.is_some();
518
519        let mut immediate_ty = None;
520        let mut immediate_predicates = Vec::new();
521        if allow_immediate_cli {
522            immediate_ty = in_situ_quant;
523            immediate_predicates =
524                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
525                    self.inner.immediate_isq_predicates_moqe(&config)?
526                } else {
527                    self.inner.immediate_isq_predicates(&config)?
528                };
529            info!("Applying ISQ to {in_situ_quant:?}");
530            if immediate_predicates.is_empty() {
531                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
532            }
533        }
534
535        let use_immediate = allow_immediate_cli || has_override_isq;
536        if use_immediate {
537            let (pool, num_threads) = hanzo_quant::create_isq_thread_pool(immediate_ty);
538            info!("Applying immediate ISQ in parallel on {num_threads} threads.");
539            hanzo_quant::set_immediate_isq_with_pool(
540                immediate_ty,
541                immediate_predicates.clone(),
542                topology_overrides.clone(),
543                pool,
544            );
545        }
546
547        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
548        let mut loading_isq = if use_immediate {
549            false
550        } else {
551            in_situ_quant.is_some()
552        };
553        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
554            loading_isq = true;
555        }
556        loading_isq |= topology_requires_post_quant;
557        loading_isq |= self.config.from_uqff.is_some();
558
559        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
560            anyhow::bail!(
561                "`imatrix` and `calibration_file` were both specified, this is not allowed."
562            );
563        }
564
565        // Load onto the regular device if not using isq or if the calibration file is specified.
566        // For immediate ISQ on discrete GPUs, load to CPU: the mapper will set the correct target
567        // device per-layer, and linear constructors will override to CPU for ISQ-targeted weights.
568        // On integrated/unified memory systems (e.g. Grace Blackwell), CPU and GPU share memory,
569        // so we load directly to the device.
570        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
571            loading_isq = false;
572            if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
573                Device::Cpu
574            } else {
575                device.clone()
576            }
577        } else {
578            Device::Cpu
579        };
580
581        let attention_mechanism = if paged_attn_config.is_some() {
582            AttentionImplementation::PagedAttention
583        } else {
584            AttentionImplementation::Eager
585        };
586
587        let multi_progress = Arc::new(new_multi_progress());
588
589        info!(
590            "{}",
591            WeightLoadingMode::from(WeightLoadingState {
592                from_uqff: self.config.from_uqff.is_some(),
593                loading_isq,
594                immediate_isq: use_immediate,
595                write_uqff: self.config.write_uqff.is_some(),
596            })
597            .message("model")
598        );
599
600        let mut model = if use_nccl || use_ring() {
601            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
602                dtype,
603                &device,
604                &available_devices,
605                silent,
606                &config,
607                loading_isq,
608                self.config.from_uqff.is_some(),
609                self.config.organization,
610                &*self.inner,
611                paths.as_ref(),
612            )?;
613
614            // Special case where things can be loaded more optimally.
615            match self.kind {
616                ModelKind::Normal => multimodal_normal_model_loader_sharded!(
617                    sharded_vb,
618                    config,
619                    self.inner,
620                    mapper,
621                    loading_isq,
622                    device.clone(),
623                    attention_mechanism,
624                    multi_progress.clone(),
625                    matformer_slicing_config.clone(),
626                ),
627                _ => unreachable!(),
628            }
629        } else {
630            match self.kind {
631                ModelKind::Normal => multimodal_normal_model_loader!(
632                    paths,
633                    Some(dtype),
634                    &load_device,
635                    layer_devices.clone(),
636                    config,
637                    self.inner,
638                    silent,
639                    mapper,
640                    loading_isq,
641                    self.config.from_uqff.is_some(),
642                    device.clone(),
643                    attention_mechanism,
644                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
645                    multi_progress,
646                    matformer_slicing_config.clone(),
647                ),
648                _ => unreachable!(),
649            }
650        };
651
652        let processor_config_json = paths
653            .get_processor_config()
654            .as_ref()
655            .map(|f| fs::read_to_string(f).unwrap());
656
657        // Handle models that only ship processor_config.json with nested
658        // image/audio preprocessor settings and no preprocessor_config.json.
659        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
660        {
661            Some(preprocessor_config) => {
662                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
663            }
664            None => processor_config_json.as_deref().map_or_else(
665                PreProcessorConfig::default,
666                |json| match PreProcessorConfig::from_processor_config_json(json) {
667                    Ok(config) => config,
668                    Err(err) => {
669                        warn!(
670                            "Failed to synthesize preprocessor config from processor_config.json: {err}"
671                        );
672                        PreProcessorConfig::default()
673                    }
674                },
675            ),
676        };
677        let processor_config: Option<ProcessorConfig> = processor_config_json
678            .as_deref()
679            .map(|json| serde_json::from_str(json).unwrap());
680
681        let processor = self.inner.get_processor(
682            &config,
683            processor_config,
684            preprocessor_config.clone(),
685            self.config.max_edge,
686        ); //There are always some repos that don't properly handle config position, for example... LLaVA
687
688        let tokenizer = get_tokenizer(
689            paths.get_tokenizer_filename(),
690            Some(processor.get_special_tokens()),
691        )?;
692
693        let gen_conf: Option<GenerationConfig> = paths
694            .get_gen_conf_filename()
695            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
696        let chat_template_explicit = paths
697            .get_chat_template_explicit()
698            .as_ref()
699            .map(|x| x.to_string_lossy().to_string());
700        let mut chat_template = get_chat_template(
701            paths,
702            self.jinja_explicit.as_ref(),
703            chat_template_explicit.as_ref(),
704            self.chat_template.as_ref(),
705            None,
706        );
707
708        // If no chat template was found, use the loader's built-in default (if any).
709        if chat_template.chat_template.is_none() {
710            if let Some(default_tmpl) = self.inner.default_chat_template(&config) {
711                info!("Using loader's built-in default chat template.");
712                chat_template.chat_template = Some(ChatTemplateValue(Either::Left(default_tmpl)));
713            }
714        }
715
716        // If no bos/eos tokens are set, use the loader's defaults (e.g. for Voxtral
717        // which has no tokenizer_config.json).
718        if let Some((bos, eos)) = self.inner.default_bos_eos(&config) {
719            if chat_template.bos_token.is_none() {
720                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos)));
721            }
722            if chat_template.eos_token.is_none() {
723                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos)));
724            }
725        }
726
727        if let Some(calibration_file) = &self.config.calibration_file {
728            let calibration_data = std::fs::read_to_string(calibration_file)?;
729            // Tokenize, don't add bos yet
730            let tokens = tokenizer
731                .encode_fast(calibration_data, false)
732                .map_err(anyhow::Error::msg)?
733                .get_ids()
734                .to_vec();
735            info!(
736                "Collecting imatrix from calibration file `{}` of {} tokens.",
737                calibration_file.display(),
738                tokens.len()
739            );
740            let bos_tok_id = chat_template
741                .bos_tok()
742                .as_deref()
743                .and_then(|tok| tokenizer.token_to_id(tok));
744
745            // NOTE: We ONLY calibrate the text bits of these models!!
746            // So only those should be tracked!
747            match self.config.organization {
748                IsqOrganization::Default => model.begin_track_stats()?,
749                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
750            }
751
752            const CHUNK_SIZE: usize = 1024;
753            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
754            let start = Instant::now();
755            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
756                let mut chunk = chunk.to_vec();
757                if let Some(bos_tok_id) = bos_tok_id {
758                    chunk.insert(0, bos_tok_id);
759                }
760                let chunk_len = chunk.len();
761
762                let start = Instant::now();
763                let inputs = make_prompt_chunk(
764                    0,
765                    vec![&chunk],
766                    &[0],
767                    &load_device,
768                    None,
769                    false,
770                    None,
771                    None,
772                    None,
773                    model.config().sliding_window,
774                )?;
775                let _ = model.forward(
776                    &inputs.input,
777                    None, // NOTE: We ONLY calibrate the text bits of these models!!
778                    &inputs.positions,
779                    inputs.context_lens,
780                    inputs.position_ids,
781                    model.default_model_specific_args(&inputs.input),
782                    None,
783                    &inputs.flash_meta,
784                )?;
785                match model.cache_mut() {
786                    EitherCache::Full(full) => {
787                        for layer in &mut *full.lock() {
788                            *layer = None
789                        }
790                    }
791                    EitherCache::Normal(normal) => {
792                        for layer in &mut *normal.lock().unwrap().0 {
793                            layer.reset();
794                        }
795                    }
796                    EitherCache::Hybrid(hybrid) => {
797                        hybrid.lock().unwrap().reset();
798                    }
799                }
800                let end = Instant::now();
801                info!(
802                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
803                    i + 1,
804                    end.duration_since(start).as_secs_f32()
805                );
806            }
807            load_device.synchronize()?;
808            let end = Instant::now();
809            info!(
810                "Finished collecting imatrix in {:.2}s",
811                end.duration_since(start).as_secs_f32()
812            );
813        }
814
815        let should_serialize = self.config.write_uqff.is_some();
816        let should_quantize_pass = loading_isq;
817
818        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
819            let imatrix_source = if should_quantize_pass {
820                match (
821                    self.config.imatrix.as_ref(),
822                    self.config.calibration_file.is_some(),
823                ) {
824                    (None, false) => None,
825                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
826                    (None, true) => Some(ImatrixDataSource::Collected),
827                    (Some(_), true) => unreachable!(),
828                }
829            } else {
830                None
831            };
832            if should_quantize_pass {
833                debug!("Applying ISQ to all ranks.");
834            } else {
835                debug!("Serializing existing ISQ tensors without additional quantization.");
836            }
837            model.quantize(
838                in_situ_quant,
839                device.clone(),
840                self.config.topology.as_ref(),
841                silent,
842                imatrix_source,
843                self.config.organization,
844                should_quantize_pass,
845                self.config.write_uqff.as_ref(),
846                UqffFullSer {
847                    tokenizer: &tokenizer,
848                    template_filename: paths.get_template_filename(),
849                    generation_config: paths.get_gen_conf_filename(),
850                    config: config.clone(),
851                    processor_filename: paths.get_processor_config(),
852                    preprocessor_filename: paths.get_preprocessor_config(),
853                    modules: None,
854                    module_paths: None,
855                },
856                Arc::new(new_multi_progress()),
857            )?;
858        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
859            model.load_from_artifacts(
860                device.clone(),
861                self.config.topology.as_ref(),
862                silent,
863                from_uqff,
864            )?;
865        }
866
867        let model_metadata = model.model_config();
868        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
869            anyhow::ensure!(
870                !matches!(self.kind, ModelKind::Adapter { .. }),
871                "PagedAttention does not support adapter models."
872            );
873            let cache_config = calculate_cache_config(
874                paged_attn_config.mem_gpu,
875                paged_attn_config.block_size,
876                dtype,
877                paged_attn_config.cache_type,
878                model_metadata.as_ref(),
879                &device,
880                &layer_devices,
881                silent,
882                None,
883                max_kv_tokens,
884            )?;
885            let cache_engine = CacheEngine::new(
886                model_metadata.as_ref(),
887                &cache_config,
888                dtype,
889                &device,
890                layer_devices,
891            )?;
892            (Some(cache_config), Some(cache_engine))
893        } else {
894            (None, None)
895        };
896
897        let max_seq_len = model.max_seq_len();
898        let llg_factory = build_llg_factory(tokenizer.clone())?;
899        let num_hidden_layers = match model.cache() {
900            EitherCache::Full(full) => full.lock().len(),
901            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
902            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
903        };
904        let generation_defaults = gen_conf
905            .as_ref()
906            .and_then(GenerationConfig::generation_defaults);
907        let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
908        let sliding_window = model.config().sliding_window;
909        Ok(Arc::new(Mutex::new(MultimodalPipeline {
910            model,
911            tokenizer: tokenizer.into(),
912            chat_template: Arc::new(chat_template),
913            model_id: self.model_id.clone(),
914            metadata: Arc::new(GeneralMetadata {
915                max_seq_len,
916                llg_factory: Some(llg_factory),
917                is_xlora: false,
918                num_hidden_layers,
919                eos_tok: eos,
920                kind: self.kind.clone(),
921                no_kv_cache: false,
922                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
923                activation_dtype: dtype,
924                sliding_window,
925                cache_config,
926                cache_engine,
927                model_metadata: Some(model_metadata),
928                modalities: self.inner.modalities(&config)?,
929            }),
930            processor,
931            prefixer: self.inner.prefixer(&config),
932            preprocessor_config: Arc::new(preprocessor_config),
933            topology: self.config.topology.clone(),
934            silent,
935            organization: self.config.organization,
936            template_filename: paths.get_template_filename().clone(),
937            generation_config: paths.get_gen_conf_filename().cloned(),
938            generation_defaults,
939            config,
940            processor_filename: paths.get_processor_config().clone(),
941            preprocessor_filename: paths.get_preprocessor_config().clone(),
942            mapper: pipeline_mapper,
943            imatrix: self.config.imatrix.clone(),
944        })))
945    }
946
947    fn get_id(&self) -> String {
948        self.model_id.to_string()
949    }
950
951    fn get_kind(&self) -> ModelKind {
952        self.kind.clone()
953    }
954}
955
956impl PreProcessingMixin for MultimodalPipeline {
957    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
958        Some(self.chat_template.clone())
959    }
960    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
961        Some(self.preprocessor_config.clone())
962    }
963    fn get_processor(&self) -> Arc<dyn super::Processor> {
964        self.processor.clone()
965    }
966}
967
968impl IsqPipelineMixin for MultimodalPipeline {
969    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
970        let device = self.device().clone();
971        self.model
972            .quantize(
973                Some(dtype),
974                device,
975                self.topology.as_ref(),
976                self.silent,
977                self.imatrix.as_ref().map(ImatrixDataSource::File),
978                self.organization,
979                true,
980                None,
981                UqffFullSer {
982                    tokenizer: &self.tokenizer,
983                    template_filename: &self.template_filename,
984                    generation_config: self.generation_config.as_ref(),
985                    config: self.config.clone(),
986                    processor_filename: &self.processor_filename,
987                    preprocessor_filename: &self.preprocessor_filename,
988                    modules: None,
989                    module_paths: None,
990                },
991                Arc::new(new_multi_progress()),
992            )
993            .map_err(anyhow::Error::msg)
994    }
995}
996
997impl CacheManagerMixin for MultimodalPipeline {
998    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
999        match self.model.cache() {
1000            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
1001            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
1002            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
1003        }
1004    }
1005    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1006        match self.model.cache() {
1007            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
1008            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
1009            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
1010        }
1011    }
1012    fn set_none_cache(
1013        &self,
1014        seqs: &mut [&mut Sequence],
1015        reset_non_granular: bool,
1016        modify_draft_cache: bool,
1017
1018        load_preallocated_cache: bool,
1019    ) {
1020        match self.model.cache() {
1021            EitherCache::Full(_) => {
1022                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
1023            }
1024            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
1025                self,
1026                seqs,
1027                modify_draft_cache,
1028                load_preallocated_cache,
1029            ),
1030            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
1031                self,
1032                seqs,
1033                modify_draft_cache,
1034                load_preallocated_cache,
1035            ),
1036        }
1037        // Always clear model-specific state (e.g. Voxtral audio_embeds_cache)
1038        // for new prompts. set_none_cache is "Only called for prompt seqs",
1039        // so this is always appropriate. Default impl is a no-op.
1040        self.model.reset_model_specific_state();
1041
1042        if reset_non_granular {
1043            self.reset_non_granular_state()
1044        }
1045    }
1046    fn cache(&self) -> &EitherCache {
1047        self.model.cache()
1048    }
1049}
1050
1051impl MetadataMixin for MultimodalPipeline {
1052    fn device(&self) -> Device {
1053        self.model.device().clone()
1054    }
1055    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1056        self.metadata.clone()
1057    }
1058    fn name(&self) -> String {
1059        self.model_id.clone()
1060    }
1061    fn reset_non_granular_state(&self) {
1062        self.model.reset_model_specific_state();
1063    }
1064    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1065        Some(self.tokenizer.clone())
1066    }
1067    fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
1068        self.generation_defaults.clone()
1069    }
1070    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1071        Some(&*self.mapper)
1072    }
1073}
1074
1075impl crate::speculative::driver::SpeculativePipelineExt for MultimodalPipeline {
1076    fn has_speculative_proposer(&self) -> bool {
1077        self.model.has_speculative_proposer()
1078    }
1079
1080    fn speculative_proposal_len(&self) -> Option<usize> {
1081        self.model.speculative_proposal_len()
1082    }
1083
1084    fn speculative_target_hiddens(
1085        &self,
1086        rows: &[(usize, usize)],
1087    ) -> hanzo_ml::Result<Option<Tensor>> {
1088        self.model.speculative_target_hiddens(rows)
1089    }
1090
1091    fn speculative_propose(
1092        &mut self,
1093        ctx: crate::speculative::SpeculativeProposeBatchCtx<'_>,
1094    ) -> hanzo_ml::Result<Option<crate::speculative::SpeculativeProposalBatch>> {
1095        self.model.speculative_propose(ctx)
1096    }
1097
1098    fn build_speculative_verify_inputs(
1099        &self,
1100        input_meta: InputMetadata,
1101    ) -> hanzo_ml::Result<Box<dyn Any>> {
1102        let model_specific_args = self.model.default_model_specific_args(&input_meta.input);
1103        Ok(Box::new(ModelInputs {
1104            input_ids: input_meta.input,
1105            seqlen_offsets: input_meta.positions,
1106            context_lens: input_meta.context_lens,
1107            position_ids: input_meta.position_ids,
1108            pixel_values: None,
1109            model_specific_args,
1110            paged_attn_meta: input_meta.paged_attn_meta,
1111            flash_meta: input_meta.flash_meta,
1112        }))
1113    }
1114}
1115
1116#[async_trait::async_trait]
1117impl Pipeline for MultimodalPipeline {
1118    fn forward_inputs(
1119        &mut self,
1120        inputs: Box<dyn Any>,
1121        return_raw_logits: bool,
1122    ) -> hanzo_ml::Result<ForwardInputsResult> {
1123        let ModelInputs {
1124            input_ids,
1125            seqlen_offsets,
1126            context_lens,
1127            position_ids,
1128            pixel_values,
1129            model_specific_args,
1130            paged_attn_meta,
1131            flash_meta,
1132        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
1133        let metadata = self.get_metadata();
1134        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1135            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
1136            (Some(_), None) => {
1137                // This can happen if Rust-side user code is wrong
1138                hanzo_ml::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1139            }
1140            (None, Some(_)) => {
1141                // This should never happen but we handle it anyway
1142                hanzo_ml::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1143            }
1144            (None, None) => None,
1145        };
1146        let logits = self.model.forward(
1147            &input_ids,
1148            pixel_values,
1149            &seqlen_offsets,
1150            context_lens,
1151            position_ids,
1152            model_specific_args,
1153            paged_attn_meta,
1154            &flash_meta,
1155        )?;
1156        if return_raw_logits {
1157            Ok(ForwardInputsResult::RawLogits { logits })
1158        } else {
1159            Ok(ForwardInputsResult::CausalGeneration { logits })
1160        }
1161    }
1162
1163    fn attach_speculative(
1164        &mut self,
1165        config: crate::speculative::SpeculativeConfig,
1166    ) -> hanzo_ml::Result<()> {
1167        if matches!(config, crate::speculative::SpeculativeConfig::Mtp(_))
1168            && self.get_metadata().cache_engine.is_none()
1169        {
1170            hanzo_ml::bail!(
1171                "MTP speculative decoding currently requires PagedAttention for this pipeline."
1172            );
1173        }
1174        if let Some(info) = self.model.attach_speculative(config)? {
1175            self.model.log_speculative_attach(&info);
1176        }
1177        Ok(())
1178    }
1179
1180    #[allow(clippy::too_many_arguments)]
1181    async fn try_sample_speculative_causal_gen(
1182        &mut self,
1183        seqs: &mut [&mut Sequence],
1184        logits: &[Tensor],
1185        prefix_cacher: &mut PrefixCacheManagerV2,
1186        disable_eos_stop: bool,
1187        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1188        metadata: Option<crate::pipeline::text_models_inputs_processor::PagedAttentionMeta>,
1189    ) -> hanzo_ml::Result<bool> {
1190        if !self.model.has_speculative_proposer() {
1191            crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1192            return Ok(false);
1193        }
1194
1195        let general_metadata = self.get_metadata();
1196        if let Some(cache_engine) = general_metadata.cache_engine.as_ref() {
1197            let Some(metadata) = metadata else {
1198                crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1199                return Ok(false);
1200            };
1201            let cache = crate::speculative::cache::PagedSpeculativeCacheAccess::new(
1202                &metadata,
1203                cache_engine,
1204            );
1205            return crate::speculative::driver::try_sample_speculative_causal_gen(
1206                self,
1207                seqs,
1208                logits,
1209                prefix_cacher,
1210                disable_eos_stop,
1211                rng,
1212                &cache,
1213            )
1214            .await;
1215        }
1216
1217        crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1218        Ok(false)
1219    }
1220
1221    async fn sample_causal_gen(
1222        &self,
1223        seqs: &mut [&mut Sequence],
1224        logits: Vec<Tensor>,
1225        prefix_cacher: &mut PrefixCacheManagerV2,
1226        disable_eos_stop: bool,
1227        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1228    ) -> Result<(), hanzo_ml::Error> {
1229        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1230    }
1231    fn category(&self) -> ModelCategory {
1232        ModelCategory::Multimodal {
1233            prefixer: self.prefixer.clone(),
1234        }
1235    }
1236
1237    fn encoder_cache_counters(
1238        &self,
1239    ) -> Option<(
1240        std::sync::Arc<std::sync::atomic::AtomicUsize>,
1241        std::sync::Arc<std::sync::atomic::AtomicUsize>,
1242    )> {
1243        self.model.encoder_cache_counters()
1244    }
1245}
1246
1247impl AnyMoePipelineMixin for MultimodalPipeline {
1248    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> hanzo_ml::Result<()> {
1249        self.model.finish_training(gate_model_id)
1250    }
1251    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1252        self.model.get_vars()
1253    }
1254    fn amoe_base_model_trainable_params(&self) -> usize {
1255        self.model.trainable_params()
1256    }
1257    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1258        self.model.take_cached_gating_outputs()
1259    }
1260    fn amoe_create_layers(
1261        &mut self,
1262        model_ids: Vec<String>,
1263        token: &TokenSource,
1264        revision: Option<String>,
1265        match_regex: &str,
1266        config: crate::amoe::AnyMoeConfig,
1267        dtype: hanzo_ml::DType,
1268        dev: &Device,
1269        (prefix, mlp): (String, String),
1270        layers: Vec<usize>,
1271        expert_type: AnyMoeExpertType,
1272        silent: bool,
1273        gate_model_id: Option<String>,
1274    ) -> hanzo_ml::Result<()> {
1275        let mut vbs = Vec::new();
1276        // Precompile regex here
1277        let regex = Regex::new(match_regex).map_err(hanzo_ml::Error::msg)?;
1278        for model_id in model_ids {
1279            let model_id_str = &model_id;
1280            let model_id = Path::new(&model_id);
1281
1282            let api = {
1283                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1284                let mut api = ApiBuilder::from_cache(cache)
1285                    .with_progress(!silent)
1286                    .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1287                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1288                    api = api.with_cache_dir(cache_dir);
1289                }
1290                api.build().map_err(hanzo_ml::Error::msg)?
1291            };
1292            let revision = revision.clone().unwrap_or("main".to_string());
1293            let api = api.repo(Repo::with_revision(
1294                model_id_str.clone(),
1295                RepoType::Model,
1296                revision.clone(),
1297            ));
1298
1299            let mut filenames = vec![];
1300            for rfilename in api_dir_list!(api, model_id, true, &revision)
1301                .filter(|x| x.ends_with(".safetensors"))
1302            {
1303                filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1304            }
1305
1306            let regex = regex.clone();
1307            let match_regex_clone = match_regex.to_string();
1308            let layers_clone = layers.clone();
1309            let vb = from_mmaped_safetensors(
1310                filenames,
1311                vec![],
1312                Some(dtype),
1313                dev,
1314                vec![None],
1315                silent,
1316                None,
1317                move |key| {
1318                    if regex.is_match(&key) {
1319                        // Idx of the last char of the layer id, +1
1320                        // Assumes N.MLP
1321                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1322                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1323                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1324                            .parse::<usize>()
1325                            .unwrap();
1326                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1327                    } else {
1328                        false
1329                    }
1330                },
1331                Arc::new(|_| DeviceForLoadTensor::Base),
1332            )?;
1333            vbs.push(vb);
1334        }
1335
1336        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1337            let model_id_str = &gate_model_id;
1338            let model_id = Path::new(&gate_model_id);
1339
1340            let api = {
1341                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1342                let mut api = ApiBuilder::from_cache(cache)
1343                    .with_progress(!silent)
1344                    .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1345                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1346                    api = api.with_cache_dir(cache_dir);
1347                }
1348                api.build().map_err(hanzo_ml::Error::msg)?
1349            };
1350            let revision = revision.clone().unwrap_or("main".to_string());
1351            let api = api.repo(Repo::with_revision(
1352                model_id_str.clone(),
1353                RepoType::Model,
1354                revision.clone(),
1355            ));
1356
1357            let mut gate_filenames = vec![];
1358            for rfilename in api_dir_list!(api, model_id, true, &revision)
1359                .filter(|x| x.ends_with(".safetensors"))
1360            {
1361                gate_filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1362            }
1363            assert_eq!(
1364                gate_filenames.len(),
1365                1,
1366                "Gate model ID must contain only one .safetensors file"
1367            );
1368
1369            let vb = from_mmaped_safetensors(
1370                gate_filenames.clone(),
1371                vec![],
1372                Some(dtype),
1373                dev,
1374                vec![None],
1375                silent,
1376                None,
1377                |_| true,
1378                Arc::new(|_| DeviceForLoadTensor::Base),
1379            )?;
1380            info!(
1381                "Loaded gating layers from `{}`",
1382                gate_filenames[0].display()
1383            );
1384            Some(vb)
1385        } else {
1386            None
1387        };
1388
1389        self.model
1390            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
1391    }
1392    fn amoe_supported(&self) -> bool {
1393        self.model.amoe_supported()
1394    }
1395}