Skip to main content

mistralrs_core/pipeline/
multimodal.rs

1use super::isq::ImatrixDataSource;
2use super::isq::UqffFullSer;
3use super::{
4    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoMultimodalLoader,
5    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
6    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
7    ModelKind, ModelPaths, MultimodalModel, MultimodalModelLoader, MultimodalPromptPrefixer,
8    Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, Qwen3VLLoader, Qwen3VLMoELoader,
9    Qwen3_5Loader, Qwen3_5MoeLoader, TokenSource, VLlama4Loader, VLlamaLoader,
10};
11use super::{
12    Gemma3nLoader, Gemma4Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
13    Mistral3Loader, MultimodalLoaderType, Phi3VLoader, Qwen2_5VLLoader, VoxtralLoader,
14};
15use crate::attention::ATTENTION_CHUNK_SIZE;
16use crate::device_map::{self, DeviceMapper};
17use crate::distributed::{self, use_ring, WorkerTransferData};
18use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
19use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
20use crate::pipeline::chat_template::{
21    calculate_eos_tokens, BeginEndUnkPadTok, ChatTemplateValue, GenerationConfig,
22};
23use crate::pipeline::llg::build_llg_factory;
24use crate::pipeline::loaders::auto_device_map;
25use crate::pipeline::loaders::QuantizationConfigShim;
26use crate::pipeline::sampling::sample_and_add_toks;
27use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
28use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
29use crate::prefix_cacher::PrefixCacheManagerV2;
30use crate::sequence::Sequence;
31use crate::utils::tokenizer::get_tokenizer;
32use crate::utils::varbuilder_utils::DeviceForLoadTensor;
33use crate::utils::{
34    progress::{new_multi_progress, ProgressScopeGuard},
35    tokens::get_token,
36    varbuilder_utils::from_mmaped_safetensors,
37};
38use crate::vision_models::preprocessor_config::PreProcessorConfig;
39use crate::vision_models::processor_config::ProcessorConfig;
40use crate::vision_models::ModelInputs;
41use crate::{
42    api_dir_list, api_get_file, get_paths, get_uqff_paths, multimodal_normal_model_loader,
43    multimodal_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
44    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
45};
46use anyhow::Result;
47use candle_core::{Device, Tensor, Var};
48use either::Either;
49use hf_hub::Cache;
50use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
51use mistralrs_quant::log::once_log_info;
52use mistralrs_quant::{
53    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
54};
55use rand_isaac::Isaac64Rng;
56use regex_automata::meta::Regex;
57use std::any::Any;
58use std::borrow::Cow;
59use std::path::{Path, PathBuf};
60use std::str::FromStr;
61use std::sync::{Arc, RwLock};
62use std::time::Instant;
63use std::{env, fs};
64use tokenizers::Tokenizer;
65use tokio::sync::Mutex;
66use tracing::{info, warn};
67
/// A fully-loaded multimodal pipeline: the model itself plus all the
/// tokenization, templating, preprocessing, and device-mapping state needed
/// to run inference.
pub struct MultimodalPipeline {
    /// The loaded multimodal model implementation.
    model: Box<dyn MultimodalModel + Send + Sync>,
    /// Tokenizer shared across sequences.
    tokenizer: Arc<Tokenizer>,
    /// Resolved chat template used to render conversations into prompts.
    chat_template: Arc<ChatTemplate>,
    /// Identifier of the model (e.g. HF repo id or local path).
    model_id: String,
    /// General pipeline metadata (cache config, limits, etc.).
    metadata: Arc<GeneralMetadata>,
    /// Input processor that turns requests into model inputs.
    processor: Arc<dyn Processor + Send + Sync>,
    /// Image/audio preprocessor configuration.
    preprocessor_config: Arc<PreProcessorConfig>,
    /// Optional device/ISQ topology overrides.
    topology: Option<Topology>,
    /// When true, progress/logging output is suppressed.
    silent: bool,
    /// Prefixes prompts with modality-specific tokens/markers.
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    /// Per-layer device mapper.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// ISQ organization (default vs. MoE-experts-only quantization).
    organization: IsqOrganization,

    // For full UQFF serialization: the original artifact paths and raw config
    // are retained so a complete UQFF snapshot can be written later.
    /// Path to the chat template file, if one was found.
    template_filename: Option<PathBuf>,
    /// Path to generation_config.json, if present.
    generation_config: Option<PathBuf>,
    /// Parsed generation defaults, if any.
    generation_defaults: Option<crate::ModelGenerationDefaults>,
    /// Raw model config JSON contents.
    config: String,
    /// Path to processor_config.json, if present.
    processor_filename: Option<PathBuf>,
    /// Path to preprocessor_config.json, if present.
    preprocessor_filename: Option<PathBuf>,
    /// Path to an imatrix file used for calibration-aware ISQ, if any.
    imatrix: Option<PathBuf>,
}
91
/// A loader for a multimodal (non-quantized) model.
pub struct MultimodalLoader {
    /// Architecture-specific loader implementation.
    inner: Box<dyn MultimodalModelLoader>,
    /// Identifier of the model (HF repo id or local path).
    model_id: String,
    /// Multimodal-specific loading configuration.
    config: MultimodalSpecificConfig,
    /// Model kind (normal or adapter-based, e.g. LoRA).
    kind: ModelKind,
    /// Optional explicit chat template override.
    chat_template: Option<String>,
    /// Optional explicit tokenizer.json override.
    tokenizer_json: Option<String>,
    /// X-LoRA model id; always `None` for this loader (set in `build`).
    xlora_model_id: Option<String>,
    /// X-LoRA adapter ordering; always `None` for this loader.
    xlora_order: Option<Ordering>,
    // The RwLock fields below are written during `load_model_from_hf` so
    // later stages (e.g. UQFF path resolution) can reuse them.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Optional explicit Jinja chat template.
    jinja_explicit: Option<String>,
    /// Optional custom Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids, when `kind` is an adapter model.
    lora_adapter_ids: Option<Vec<String>>,
}
109
#[derive(Default)]
/// A builder for a loader for a multimodal (non-quantized) model.
pub struct MultimodalLoaderBuilder {
    /// Identifier of the model; required before `build` is called.
    model_id: Option<String>,
    /// Multimodal-specific loading configuration.
    config: MultimodalSpecificConfig,
    /// Model kind; defaults to normal, switched by `with_lora`.
    kind: ModelKind,
    /// Optional explicit chat template override.
    chat_template: Option<String>,
    /// Optional explicit tokenizer.json override.
    tokenizer_json: Option<String>,
    /// Optional explicit Jinja chat template.
    jinja_explicit: Option<String>,
    /// Optional custom Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids, set via `with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
}
122
#[derive(Clone, Default)]
/// Config specific to loading a multimodal model.
pub struct MultimodalSpecificConfig {
    /// Optional device/ISQ topology overrides.
    pub topology: Option<Topology>,
    /// If set, serialize the (quantized) model to this UQFF path.
    pub write_uqff: Option<PathBuf>,
    /// If set, load pre-quantized weights from these UQFF artifacts.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Optional maximum image edge length for preprocessing.
    pub max_edge: Option<u32>,
    /// Optional precomputed imatrix file for calibration-aware ISQ.
    /// Mutually exclusive with `calibration_file`.
    pub imatrix: Option<PathBuf>,
    /// Optional text calibration file used to collect an imatrix at load
    /// time. Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// Optional custom Hugging Face cache directory.
    pub hf_cache_path: Option<PathBuf>,
    /// Optional Matformer slicing config file.
    pub matformer_config_path: Option<PathBuf>,
    /// Named Matformer slice to use from the config above.
    pub matformer_slice_name: Option<String>,
    /// ISQ organization (default vs. MoE-experts-only quantization).
    pub organization: IsqOrganization,
}
137
138impl MultimodalLoaderBuilder {
139    pub fn new(
140        config: MultimodalSpecificConfig,
141        chat_template: Option<String>,
142        tokenizer_json: Option<String>,
143        model_id: Option<String>,
144        jinja_explicit: Option<String>,
145    ) -> Self {
146        Self {
147            config,
148            chat_template,
149            tokenizer_json,
150            model_id,
151            jinja_explicit,
152            kind: ModelKind::Normal,
153            hf_cache_path: None,
154            ..Default::default()
155        }
156    }
157
158    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
159        self.hf_cache_path = Some(hf_cache_path);
160        self
161    }
162
163    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
164        self.kind = ModelKind::Adapter {
165            adapter: AdapterKind::Lora,
166        };
167        self.lora_adapter_ids = Some(lora_adapter_ids);
168        self
169    }
170
171    pub fn build(self, loader: Option<MultimodalLoaderType>) -> Box<dyn Loader> {
172        let loader: Box<dyn MultimodalModelLoader> = match loader {
173            Some(MultimodalLoaderType::Phi3V) => Box::new(Phi3VLoader),
174            Some(MultimodalLoaderType::Idefics2) => Box::new(Idefics2Loader),
175            Some(MultimodalLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
176            Some(MultimodalLoaderType::LLaVA) => Box::new(LLaVALoader),
177            Some(MultimodalLoaderType::VLlama) => Box::new(VLlamaLoader),
178            Some(MultimodalLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
179            Some(MultimodalLoaderType::Idefics3) => Box::new(Idefics3Loader),
180            Some(MultimodalLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
181            Some(MultimodalLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
182            Some(MultimodalLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
183            Some(MultimodalLoaderType::Gemma3) => Box::new(Gemma3Loader),
184            Some(MultimodalLoaderType::Mistral3) => Box::new(Mistral3Loader),
185            Some(MultimodalLoaderType::Llama4) => Box::new(VLlama4Loader),
186            Some(MultimodalLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
187            Some(MultimodalLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
188            Some(MultimodalLoaderType::Qwen3VLMoE) => Box::new(Qwen3VLMoELoader),
189            Some(MultimodalLoaderType::Qwen3_5) => Box::new(Qwen3_5Loader),
190            Some(MultimodalLoaderType::Qwen3_5Moe) => Box::new(Qwen3_5MoeLoader),
191            Some(MultimodalLoaderType::Voxtral) => Box::new(VoxtralLoader),
192            Some(MultimodalLoaderType::Gemma4) => Box::new(Gemma4Loader),
193            None => Box::new(AutoMultimodalLoader),
194        };
195        Box::new(MultimodalLoader {
196            inner: loader,
197            model_id: self.model_id.unwrap(),
198            config: self.config,
199            kind: self.kind,
200            chat_template: self.chat_template,
201            tokenizer_json: self.tokenizer_json,
202            xlora_model_id: None,
203            xlora_order: None,
204            jinja_explicit: self.jinja_explicit,
205            token_source: RwLock::new(None),
206            revision: RwLock::new(None),
207            from_uqff: RwLock::new(None),
208            hf_cache_path: self.hf_cache_path,
209            lora_adapter_ids: self.lora_adapter_ids,
210        })
211    }
212}
213
214impl Loader for MultimodalLoader {
    /// Resolve all model artifact paths from the Hugging Face Hub (or local
    /// cache), record the token source / revision for later reuse, resolve
    /// any UQFF artifact paths, then delegate to
    /// [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Scoped progress reporting; suppressed when `silent` is set.
        let _progress_guard = ProgressScopeGuard::new(silent);
        // Install the HF cache (custom path if configured) process-wide,
        // exactly once — later calls keep the first-initialized cache.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        // Gather config/tokenizer/weight paths; the final flag tells the
        // macro whether UQFF artifacts are expected.
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // Persist the token source and revision so later stages (e.g. the
        // UQFF path resolution just below) can reuse them via `self`.
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision.clone();
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            // Resolve (and possibly download) the UQFF artifact paths.
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }
263
264    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
265    fn load_model_from_path(
266        &self,
267        paths: &Box<dyn ModelPaths>,
268        dtype: &dyn TryIntoDType,
269        device: &Device,
270        silent: bool,
271        mut mapper: DeviceMapSetting,
272        in_situ_quant: Option<IsqType>,
273        mut paged_attn_config: Option<PagedAttentionConfig>,
274    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
275        let _progress_guard = ProgressScopeGuard::new(silent);
276        let config = std::fs::read_to_string(paths.get_config_filename())?;
277
278        if !self.inner.supports_paged_attention(&config) {
279            paged_attn_config = None;
280        }
281
282        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
283
284        let use_nccl = mistralrs_quant::distributed::use_nccl();
285
286        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
287            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
288            let WorkerTransferData::Init { id: _, worker_rank } = payload;
289            // Use new_cuda instead of new_cuda_with_stream for NCCL compatibility
290            // NCCL manages its own streams, so explicit stream creation can cause conflicts
291            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
292        } else if use_nccl || use_ring() {
293            vec![candle_core::Device::new_cuda(0)?]
294        } else {
295            device_map::get_all_similar_devices(device)?
296        };
297        #[cfg(feature = "cuda")]
298        for device in &available_devices {
299            if let Device::Cuda(dev) = device {
300                unsafe { dev.disable_event_tracking() };
301            }
302        }
303        let device = if use_nccl || use_ring() {
304            available_devices[0].clone()
305        } else {
306            device.clone()
307        };
308
309        // Load matformer slicing config if provided
310        let matformer_slicing_config = if let Some(matformer_path) =
311            &self.config.matformer_config_path
312        {
313            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
314            info!("Loading Matformer config from {:?}", matformer_path);
315            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
316
317            if let Some(slice_name) = &self.config.matformer_slice_name {
318                info!("Using Matformer slice: {}", slice_name);
319                Some(MatformerSliceConfig::new(slice_name.clone(), config))
320            } else {
321                // If no slice name is provided but config exists, we'll need to handle this
322                // For now, return None and let the model handle the default slice selection
323                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
324                None
325            }
326        } else {
327            None
328        };
329
330        // If auto, convert to Map if not using nccl
331        let mut max_kv_tokens: Option<usize> = None;
332        if use_nccl || use_ring() {
333            mapper = DeviceMapSetting::DummyNccl {
334                nm_device: available_devices[0].clone(),
335            };
336        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
337            // We can promote to multimodal params if we get text params
338            params = params.maybe_promote_to_multimodal();
339            max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
340
341            // Initial dtype
342            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
343
344            // ISQ or UQFF: quantized path
345            // Match logic below where UQFF has priority
346            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
347                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
348                    let weight_pack_factor = {
349                        let ser_artifacts = unsafe {
350                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
351                        };
352                        let mut total_pack_factors = 0;
353                        let total_tensors = ser_artifacts.tensors().len();
354                        for (_, artifact) in ser_artifacts.tensors() {
355                            let artifact = artifact.data();
356                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
357                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
358                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
359                            {
360                                QuantizedSerdeType::Hqq => {
361                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
362                                        .pack_factor(dtype)
363                                }
364                                QuantizedSerdeType::Gguf => {
365                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
366                                        .pack_factor(dtype)
367                                }
368                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
369                                QuantizedSerdeType::Unquant => 1,
370                                QuantizedSerdeType::Afq => {
371                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
372                                        .pack_factor(dtype)
373                                }
374                                QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
375                                QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
376                            };
377                            total_pack_factors += pack_factor;
378                        }
379
380                        total_pack_factors / total_tensors
381                    };
382
383                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
384                        &config,
385                        dtype,
386                        weight_pack_factor,
387                        matformer_slicing_config.as_ref(),
388                    )?;
389                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
390                        &config,
391                        dtype,
392                        weight_pack_factor,
393                        matformer_slicing_config.as_ref(),
394                    )?;
395                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
396                    (
397                        layer_sizes_in_bytes,
398                        non_mapped_size_in_bytes,
399                        layer_sizes_sum + non_mapped_size_in_bytes,
400                    )
401                } else if let Some(isq) = in_situ_quant {
402                    let weight_pack_factor = isq.pack_factor(dtype);
403                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
404                        &config,
405                        dtype,
406                        weight_pack_factor,
407                        matformer_slicing_config.as_ref(),
408                    )?;
409                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
410                        &config,
411                        dtype,
412                        weight_pack_factor,
413                        matformer_slicing_config.as_ref(),
414                    )?;
415                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
416                    (
417                        layer_sizes_in_bytes,
418                        non_mapped_size_in_bytes,
419                        layer_sizes_sum + non_mapped_size_in_bytes,
420                    )
421                } else {
422                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
423                    let weight_pack_factor =
424                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
425                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
426                        &config,
427                        dtype,
428                        weight_pack_factor,
429                        matformer_slicing_config.as_ref(),
430                    )?;
431                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
432                        &config,
433                        dtype,
434                        weight_pack_factor,
435                        matformer_slicing_config.as_ref(),
436                    )?;
437                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
438                    (
439                        layer_sizes_in_bytes,
440                        non_mapped_size_in_bytes,
441                        layer_sizes_sum + non_mapped_size_in_bytes,
442                    )
443                };
444
445            let new = auto_device_map::get_device_layers(
446                &*self.inner,
447                &config,
448                self.inner.num_layers(&config)?,
449                layer_sizes_in_bytes,
450                non_mapped_size_in_bytes,
451                total_model_size_in_bytes,
452                &available_devices,
453                dtype,
454                &params,
455                paged_attn_config.as_ref(),
456            )?;
457            mapper = DeviceMapSetting::Map(new);
458        }
459
460        let pipeline_mapper = mapper.into_mapper(
461            self.inner.num_layers(&config)?,
462            &device,
463            self.config.topology.as_ref(),
464            &available_devices,
465        )?;
466        let mapper = mapper.into_mapper(
467            self.inner.num_layers(&config)?,
468            &device,
469            self.config.topology.as_ref(),
470            &available_devices,
471        )?;
472        let mut layer_devices = Vec::new();
473        for layer in 0..self.inner.num_layers(&config)? {
474            let device = mapper.device_for(layer, false).cloned();
475            layer_devices.push(device);
476        }
477        let dtype = mapper.get_min_dtype(dtype)?;
478
479        // TODO: PagedAttention is not supported with CPU for now.
480        // This check is not really necessary because `get_device_layers` should prevent it.
481        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
482        if mapping_uses_cpu && paged_attn_config.is_some() {
483            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
484            paged_attn_config = None;
485        }
486
487        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
488        if crate::using_flash_attn() {
489            once_log_info("FlashAttention is enabled.");
490        }
491
492        let topology_overrides = self
493            .config
494            .topology
495            .as_ref()
496            .map(|topology| {
497                topology
498                    .pattern_overrides()
499                    .into_iter()
500                    .map(|(regex, layer)| ImmediateIsqOverride {
501                        predicate: regex,
502                        ty: layer.isq,
503                        device: layer.device.clone(),
504                    })
505                    .collect::<Vec<_>>()
506            })
507            .unwrap_or_default();
508        let has_override_isq = topology_overrides
509            .iter()
510            .any(|override_entry| override_entry.ty.is_some());
511        let topology_requires_post_quant = self
512            .config
513            .topology
514            .as_ref()
515            .is_some_and(|topology| topology.requires_post_quantization());
516
517        let allow_immediate_cli = self.config.imatrix.is_none()
518            && self.config.calibration_file.is_none()
519            && in_situ_quant.is_some();
520
521        let mut immediate_ty = None;
522        let mut immediate_predicates = Vec::new();
523        if allow_immediate_cli {
524            immediate_ty = in_situ_quant;
525            immediate_predicates =
526                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
527                    self.inner.immediate_isq_predicates_moqe(&config)?
528                } else {
529                    self.inner.immediate_isq_predicates(&config)?
530                };
531            info!("Applying ISQ to {in_situ_quant:?}");
532            if immediate_predicates.is_empty() {
533                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
534            }
535        }
536
537        let use_immediate = allow_immediate_cli || has_override_isq;
538        if use_immediate {
539            let (pool, num_threads) = mistralrs_quant::create_isq_thread_pool(immediate_ty);
540            info!("Applying immediate ISQ in parallel on {num_threads} threads.");
541            mistralrs_quant::set_immediate_isq_with_pool(
542                immediate_ty,
543                immediate_predicates.clone(),
544                topology_overrides.clone(),
545                pool,
546            );
547        }
548
549        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
550        let mut loading_isq = if use_immediate {
551            false
552        } else {
553            in_situ_quant.is_some()
554        };
555        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
556            loading_isq = true;
557        }
558        loading_isq |= topology_requires_post_quant;
559        loading_isq |= self.config.from_uqff.is_some();
560
561        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
562            anyhow::bail!(
563                "`imatrix` and `calibration_file` were both specified, this is not allowed."
564            );
565        }
566
567        // Load onto the regular device if not using isq or if the calibration file is specified.
568        // For immediate ISQ on discrete GPUs, load to CPU: the mapper will set the correct target
569        // device per-layer, and linear constructors will override to CPU for ISQ-targeted weights.
570        // On integrated/unified memory systems (e.g. Grace Blackwell), CPU and GPU share memory,
571        // so we load directly to the device.
572        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
573            loading_isq = false;
574            if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
575                Device::Cpu
576            } else {
577                device.clone()
578            }
579        } else {
580            Device::Cpu
581        };
582
583        let attention_mechanism = if paged_attn_config.is_some() {
584            AttentionImplementation::PagedAttention
585        } else {
586            AttentionImplementation::Eager
587        };
588
589        let multi_progress = Arc::new(new_multi_progress());
590
591        let mut model = if use_nccl || use_ring() {
592            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
593                dtype,
594                &device,
595                &available_devices,
596                silent,
597                &config,
598                loading_isq,
599                self.config.from_uqff.is_some(),
600                self.config.organization,
601                &*self.inner,
602                paths.as_ref(),
603            )?;
604
605            // Special case for where things can be more optimially loaded.
606            match self.kind {
607                ModelKind::Normal => multimodal_normal_model_loader_sharded!(
608                    sharded_vb,
609                    config,
610                    self.inner,
611                    mapper,
612                    loading_isq,
613                    device.clone(),
614                    attention_mechanism,
615                    multi_progress.clone(),
616                    matformer_slicing_config.clone(),
617                ),
618                _ => unreachable!(),
619            }
620        } else {
621            match self.kind {
622                ModelKind::Normal => multimodal_normal_model_loader!(
623                    paths,
624                    Some(dtype),
625                    &load_device,
626                    layer_devices.clone(),
627                    config,
628                    self.inner,
629                    silent,
630                    mapper,
631                    loading_isq,
632                    self.config.from_uqff.is_some(),
633                    device.clone(),
634                    attention_mechanism,
635                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
636                    multi_progress,
637                    matformer_slicing_config.clone(),
638                ),
639                _ => unreachable!(),
640            }
641        };
642
643        let processor_config_json = paths
644            .get_processor_config()
645            .as_ref()
646            .map(|f| fs::read_to_string(f).unwrap());
647
648        // Handle models that only ship processor_config.json with nested
649        // image/audio preprocessor settings and no preprocessor_config.json.
650        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
651        {
652            Some(preprocessor_config) => {
653                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
654            }
655            None => processor_config_json.as_deref().map_or_else(
656                PreProcessorConfig::default,
657                |json| match PreProcessorConfig::from_processor_config_json(json) {
658                    Ok(config) => config,
659                    Err(err) => {
660                        warn!(
661                            "Failed to synthesize preprocessor config from processor_config.json: {err}"
662                        );
663                        PreProcessorConfig::default()
664                    }
665                },
666            ),
667        };
668        let processor_config: Option<ProcessorConfig> = processor_config_json
669            .as_deref()
670            .map(|json| serde_json::from_str(json).unwrap());
671
672        let processor = self.inner.get_processor(
673            &config,
674            processor_config,
675            preprocessor_config.clone(),
676            self.config.max_edge,
677        ); //There are always some repos that don't properly handle config position, for example... LLaVA
678
679        let tokenizer = get_tokenizer(
680            paths.get_tokenizer_filename(),
681            Some(processor.get_special_tokens()),
682        )?;
683
684        let gen_conf: Option<GenerationConfig> = paths
685            .get_gen_conf_filename()
686            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
687        let chat_template_explicit = paths
688            .get_chat_template_explicit()
689            .as_ref()
690            .map(|x| x.to_string_lossy().to_string());
691        let mut chat_template = get_chat_template(
692            paths,
693            self.jinja_explicit.as_ref(),
694            chat_template_explicit.as_ref(),
695            self.chat_template.as_ref(),
696            None,
697        );
698
699        // If no chat template was found, use the loader's built-in default (if any).
700        if chat_template.chat_template.is_none() {
701            if let Some(default_tmpl) = self.inner.default_chat_template(&config) {
702                info!("Using loader's built-in default chat template.");
703                chat_template.chat_template = Some(ChatTemplateValue(Either::Left(default_tmpl)));
704            }
705        }
706
707        // If no bos/eos tokens are set, use the loader's defaults (e.g. for Voxtral
708        // which has no tokenizer_config.json).
709        if let Some((bos, eos)) = self.inner.default_bos_eos(&config) {
710            if chat_template.bos_token.is_none() {
711                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos)));
712            }
713            if chat_template.eos_token.is_none() {
714                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos)));
715            }
716        }
717
718        if let Some(calibration_file) = &self.config.calibration_file {
719            let calibration_data = std::fs::read_to_string(calibration_file)?;
720            // Tokenize, don't add bos yet
721            let tokens = tokenizer
722                .encode_fast(calibration_data, false)
723                .map_err(anyhow::Error::msg)?
724                .get_ids()
725                .to_vec();
726            info!(
727                "Collecting imatrix from calibration file `{}` of {} tokens.",
728                calibration_file.display(),
729                tokens.len()
730            );
731            let bos_tok_id = chat_template
732                .bos_tok()
733                .as_deref()
734                .and_then(|tok| tokenizer.token_to_id(tok));
735
736            // NOTE: We ONLY calibrate the text bits of these models!!
737            // So only those should be tracked!
738            match self.config.organization {
739                IsqOrganization::Default => model.begin_track_stats()?,
740                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
741            }
742
743            const CHUNK_SIZE: usize = 1024;
744            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
745            let start = Instant::now();
746            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
747                let mut chunk = chunk.to_vec();
748                if let Some(bos_tok_id) = bos_tok_id {
749                    chunk.insert(0, bos_tok_id);
750                }
751                let chunk_len = chunk.len();
752
753                let start = Instant::now();
754                let inputs = make_prompt_chunk(
755                    0,
756                    vec![&chunk],
757                    &[0],
758                    &load_device,
759                    None,
760                    false,
761                    None,
762                    None,
763                    None,
764                    model.config().sliding_window,
765                )?;
766                let _ = model.forward(
767                    &inputs.input,
768                    None, // NOTE: We ONLY calibrate the text bits of these models!!
769                    &inputs.positions,
770                    inputs.context_lens,
771                    inputs.position_ids,
772                    model.default_model_specific_args(&inputs.input),
773                    None,
774                    &inputs.flash_meta,
775                )?;
776                match model.cache_mut() {
777                    EitherCache::Full(full) => {
778                        for layer in &mut *full.lock() {
779                            *layer = None
780                        }
781                    }
782                    EitherCache::Normal(normal) => {
783                        for layer in &mut *normal.lock().unwrap().0 {
784                            layer.reset();
785                        }
786                    }
787                    EitherCache::Hybrid(hybrid) => {
788                        hybrid.lock().unwrap().reset();
789                    }
790                }
791                let end = Instant::now();
792                info!(
793                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
794                    i + 1,
795                    end.duration_since(start).as_secs_f32()
796                );
797            }
798            load_device.synchronize()?;
799            let end = Instant::now();
800            info!(
801                "Finished collecting imatrix in {:.2}s",
802                end.duration_since(start).as_secs_f32()
803            );
804        }
805
806        let should_serialize = self.config.write_uqff.is_some();
807        let should_quantize_pass = loading_isq;
808
809        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
810            let imatrix_source = if should_quantize_pass {
811                match (
812                    self.config.imatrix.as_ref(),
813                    self.config.calibration_file.is_some(),
814                ) {
815                    (None, false) => None,
816                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
817                    (None, true) => Some(ImatrixDataSource::Collected),
818                    (Some(_), true) => unreachable!(),
819                }
820            } else {
821                None
822            };
823            if should_quantize_pass {
824                info!("Applying ISQ to all ranks.");
825            } else {
826                info!("Serializing existing ISQ tensors without additional quantization.");
827            }
828            model.quantize(
829                in_situ_quant,
830                device.clone(),
831                self.config.topology.as_ref(),
832                silent,
833                imatrix_source,
834                self.config.organization,
835                should_quantize_pass,
836                self.config.write_uqff.as_ref(),
837                UqffFullSer {
838                    tokenizer: &tokenizer,
839                    template_filename: paths.get_template_filename(),
840                    generation_config: paths.get_gen_conf_filename(),
841                    config: config.clone(),
842                    processor_filename: paths.get_processor_config(),
843                    preprocessor_filename: paths.get_preprocessor_config(),
844                    modules: None,
845                    module_paths: None,
846                },
847                Arc::new(new_multi_progress()),
848            )?;
849        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
850            model.load_from_artifacts(
851                device.clone(),
852                self.config.topology.as_ref(),
853                silent,
854                from_uqff,
855            )?;
856        }
857
858        let model_metadata = model.model_config();
859        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
860            anyhow::ensure!(
861                !matches!(self.kind, ModelKind::Adapter { .. }),
862                "PagedAttention does not support adapter models."
863            );
864            let cache_config = calculate_cache_config(
865                paged_attn_config.mem_gpu,
866                paged_attn_config.block_size,
867                dtype,
868                paged_attn_config.cache_type,
869                model_metadata.as_ref(),
870                &device,
871                &layer_devices,
872                silent,
873                None,
874                max_kv_tokens,
875            )?;
876            let cache_engine = CacheEngine::new(
877                model_metadata.as_ref(),
878                &cache_config,
879                dtype,
880                &device,
881                layer_devices,
882            )?;
883            (Some(cache_config), Some(cache_engine))
884        } else {
885            (None, None)
886        };
887
888        let max_seq_len = model.max_seq_len();
889        let llg_factory = build_llg_factory(tokenizer.clone())?;
890        let num_hidden_layers = match model.cache() {
891            EitherCache::Full(full) => full.lock().len(),
892            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
893            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
894        };
895        let generation_defaults = gen_conf
896            .as_ref()
897            .and_then(GenerationConfig::generation_defaults);
898        let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
899        let sliding_window = model.config().sliding_window;
900        Ok(Arc::new(Mutex::new(MultimodalPipeline {
901            model,
902            tokenizer: tokenizer.into(),
903            chat_template: Arc::new(chat_template),
904            model_id: self.model_id.clone(),
905            metadata: Arc::new(GeneralMetadata {
906                max_seq_len,
907                llg_factory: Some(llg_factory),
908                is_xlora: false,
909                num_hidden_layers,
910                eos_tok: eos,
911                kind: self.kind.clone(),
912                no_kv_cache: false,
913                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
914                activation_dtype: dtype,
915                sliding_window,
916                cache_config,
917                cache_engine,
918                model_metadata: Some(model_metadata),
919                modalities: self.inner.modalities(&config)?,
920            }),
921            processor,
922            prefixer: self.inner.prefixer(&config),
923            preprocessor_config: Arc::new(preprocessor_config),
924            topology: self.config.topology.clone(),
925            silent,
926            organization: self.config.organization,
927            template_filename: paths.get_template_filename().clone(),
928            generation_config: paths.get_gen_conf_filename().cloned(),
929            generation_defaults,
930            config,
931            processor_filename: paths.get_processor_config().clone(),
932            preprocessor_filename: paths.get_preprocessor_config().clone(),
933            mapper: pipeline_mapper,
934            imatrix: self.config.imatrix.clone(),
935        })))
936    }
937
938    fn get_id(&self) -> String {
939        self.model_id.to_string()
940    }
941
942    fn get_kind(&self) -> ModelKind {
943        self.kind.clone()
944    }
945}
946
947impl PreProcessingMixin for MultimodalPipeline {
948    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
949        Some(self.chat_template.clone())
950    }
951    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
952        Some(self.preprocessor_config.clone())
953    }
954    fn get_processor(&self) -> Arc<dyn super::Processor> {
955        self.processor.clone()
956    }
957}
958
959impl IsqPipelineMixin for MultimodalPipeline {
960    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
961        let device = self.device().clone();
962        self.model
963            .quantize(
964                Some(dtype),
965                device,
966                self.topology.as_ref(),
967                self.silent,
968                self.imatrix.as_ref().map(ImatrixDataSource::File),
969                self.organization,
970                true,
971                None,
972                UqffFullSer {
973                    tokenizer: &self.tokenizer,
974                    template_filename: &self.template_filename,
975                    generation_config: self.generation_config.as_ref(),
976                    config: self.config.clone(),
977                    processor_filename: &self.processor_filename,
978                    preprocessor_filename: &self.preprocessor_filename,
979                    modules: None,
980                    module_paths: None,
981                },
982                Arc::new(new_multi_progress()),
983            )
984            .map_err(anyhow::Error::msg)
985    }
986}
987
impl CacheManagerMixin for MultimodalPipeline {
    /// Restores per-sequence KV caches into the model, dispatching on the
    /// model's cache flavor (full / normal / hybrid).
    // The trailing `false` matches the other arms; presumably it disables
    // draft-model cache handling — TODO confirm against the CacheManager trait.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
        }
    }
    /// Copies the model's KV cache state back out into the sequences, using
    /// the same cache-flavor dispatch as `clone_in_cache`.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
        }
    }
    /// Clears the KV cache ahead of a fresh prompt, then resets any
    /// model-specific state and (optionally) the non-granular state.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,

        load_preallocated_cache: bool,
    ) {
        match self.model.cache() {
            // The Full manager is passed `false` here instead of
            // `load_preallocated_cache` — the full cache evidently has no
            // preallocation to load; confirm against FullCacheManager.
            EitherCache::Full(_) => {
                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
            }
            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
        }
        // Always clear model-specific state (e.g. Voxtral audio_embeds_cache)
        // for new prompts. set_none_cache is "Only called for prompt seqs",
        // so this is always appropriate. Default impl is a no-op.
        self.model.reset_model_specific_state();

        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    /// Direct access to the model's cache enum.
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}
1041
1042impl MetadataMixin for MultimodalPipeline {
1043    fn device(&self) -> Device {
1044        self.model.device().clone()
1045    }
1046    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1047        self.metadata.clone()
1048    }
1049    fn name(&self) -> String {
1050        self.model_id.clone()
1051    }
1052    fn reset_non_granular_state(&self) {
1053        self.model.reset_model_specific_state();
1054    }
1055    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1056        Some(self.tokenizer.clone())
1057    }
1058    fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
1059        self.generation_defaults.clone()
1060    }
1061    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1062        Some(&*self.mapper)
1063    }
1064}
1065
#[async_trait::async_trait]
impl Pipeline for MultimodalPipeline {
    /// Runs one forward pass over prepared multimodal inputs and returns the
    /// logits, either raw or marked for causal-generation sampling.
    ///
    /// `inputs` must be a boxed `ModelInputs` produced by the input processor;
    /// anything else is a programming error and panics on the downcast.
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        // Pair the PagedAttention input metadata with the cache engine's KV
        // cache. Both must be present together or absent together; a mismatch
        // means the scheduler and pipeline were configured inconsistently.
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    /// Samples next tokens for each sequence from the forward-pass logits,
    /// delegating to the shared `sample_and_add_toks` helper.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    /// Always the multimodal category, carrying this pipeline's prompt
    /// prefixer for media handling.
    fn category(&self) -> ModelCategory {
        ModelCategory::Multimodal {
            prefixer: self.prefixer.clone(),
        }
    }

    /// A pair of atomic counters for the encoder (vision/audio) cache, when
    /// the underlying model exposes them. The exact semantics of the pair
    /// (presumably hits/misses) are defined by the model implementation.
    fn encoder_cache_counters(
        &self,
    ) -> Option<(
        std::sync::Arc<std::sync::atomic::AtomicUsize>,
        std::sync::Arc<std::sync::atomic::AtomicUsize>,
    )> {
        self.model.encoder_cache_counters()
    }
}
1137
impl AnyMoePipelineMixin for MultimodalPipeline {
    /// Finalizes AnyMoE training, optionally saving gating layers under
    /// `gate_model_id` (delegates to the model).
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    /// Trainable variables per AnyMoE layer.
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    /// Number of trainable parameters in the base model.
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    /// Drains the gating outputs cached during the last forward passes.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    /// Downloads expert weights (and optionally gating weights) from the Hub
    /// and installs AnyMoE layers into the model.
    ///
    /// For each id in `model_ids`, every `.safetensors` file in the repo is
    /// fetched and loaded, keeping only tensors whose key matches
    /// `match_regex` and whose parsed layer index is in `layers` (an empty
    /// `layers` keeps all matching layers). If `gate_model_id` is given, that
    /// repo must contain exactly one `.safetensors` file with the gating
    /// weights.
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            // Build a Hub API client honoring the global cache and an
            // optional overridden cache directory.
            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
                    api = api.with_cache_dir(cache_dir);
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            // Fetch every safetensors shard in the repo.
            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            // Clones moved into the `move` predicate closure below.
            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                // Keep only tensors whose key matches the regex and whose
                // layer index is requested. The index is parsed from the key
                // segment just before the regex match; keys not shaped like
                // `...<layer_n>.<match>` will panic on the unwraps below.
                move |key| {
                    if regex.is_match(&key) {
                        // Idx of the last char of the layer id, +1
                        // Assumes N.MLP
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        // Optional gating-weight repo: must hold exactly one safetensors file.
        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            // Same API-client construction as for the expert repos above.
            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
                    api = api.with_cache_dir(cache_dir);
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            // Gating weights: no key filtering, load everything.
            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    /// Whether the underlying model supports AnyMoE at all.
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}