1use super::isq::{ImatrixDataSource, UqffFullSer, WeightLoadingMode, WeightLoadingState};
2use super::{
3 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoMultimodalLoader,
4 CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
5 GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
6 ModelKind, ModelPaths, MultimodalModel, MultimodalModelLoader, MultimodalPromptPrefixer,
7 Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, Qwen3VLLoader, Qwen3VLMoELoader,
8 Qwen3_5Loader, Qwen3_5MoeLoader, TokenSource, VLlama4Loader, VLlamaLoader,
9};
10use super::{
11 Gemma3nLoader, Gemma4Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
12 Mistral3Loader, MultimodalLoaderType, Phi3VLoader, Qwen2_5VLLoader, VoxtralLoader,
13};
14use crate::attention::ATTENTION_CHUNK_SIZE;
15use crate::device_map::{self, DeviceMapper};
16use crate::distributed::{self, use_ring, WorkerTransferData};
17use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
18use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
19use crate::pipeline::chat_template::{
20 calculate_eos_tokens, BeginEndUnkPadTok, ChatTemplateValue, GenerationConfig,
21};
22use crate::pipeline::llg::build_llg_factory;
23use crate::pipeline::loaders::auto_device_map;
24use crate::pipeline::loaders::QuantizationConfigShim;
25use crate::pipeline::sampling::sample_and_add_toks;
26use crate::pipeline::text_models_inputs_processor::{make_prompt_chunk, InputMetadata};
27use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
28use crate::prefix_cacher::PrefixCacheManagerV2;
29use crate::sequence::Sequence;
30use crate::utils::tokenizer::get_tokenizer;
31use crate::utils::varbuilder_utils::DeviceForLoadTensor;
32use crate::utils::{
33 progress::{new_multi_progress, ProgressScopeGuard},
34 tokens::get_token,
35 varbuilder_utils::from_mmaped_safetensors,
36};
37use crate::vision_models::preprocessor_config::PreProcessorConfig;
38use crate::vision_models::processor_config::ProcessorConfig;
39use crate::vision_models::ModelInputs;
40use crate::{
41 api_dir_list, api_get_file, get_paths, get_uqff_paths, multimodal_normal_model_loader,
42 multimodal_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
43 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
44};
45use anyhow::Result;
46use either::Either;
47use hanzo_ml::{Device, Tensor, Var};
48use hanzo_quant::log::once_log_info;
49use hanzo_quant::{
50 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
51};
52use hf_hub::Cache;
53use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
54use rand_isaac::Isaac64Rng;
55use regex_automata::meta::Regex;
56use std::any::Any;
57use std::borrow::Cow;
58use std::path::{Path, PathBuf};
59use std::str::FromStr;
60use std::sync::{Arc, RwLock};
61use std::time::Instant;
62use std::{env, fs};
63use tokenizers::Tokenizer;
64use tokio::sync::Mutex;
65use tracing::{debug, info, trace, warn};
66
/// A loaded multimodal model bundled with the tokenizer, processors and metadata needed to
/// run it as a pipeline.
///
/// Instances are assembled at the end of `MultimodalLoader::load_model_from_path`.
pub struct MultimodalPipeline {
    /// The loaded model; forward, cache, quantization and speculative calls delegate to it.
    model: Box<dyn MultimodalModel + Send + Sync>,
    /// Tokenizer loaded from the model's tokenizer file.
    tokenizer: Arc<Tokenizer>,
    /// Chat template resolved at load time.
    chat_template: Arc<ChatTemplate>,
    /// Identifier of the model; returned by `MetadataMixin::name`.
    model_id: String,
    /// Shared pipeline metadata (EOS tokens, cache configuration, modalities, ...).
    metadata: Arc<GeneralMetadata>,
    /// Input processor obtained from the model-specific loader.
    processor: Arc<dyn Processor + Send + Sync>,
    /// Preprocessor configuration, exposed through `get_input_processor_config`.
    preprocessor_config: Arc<PreProcessorConfig>,
    /// Optional topology, forwarded to `quantize` when re-applying ISQ.
    topology: Option<Topology>,
    /// Whether loading/quantization was requested to be silent; forwarded to `quantize`.
    silent: bool,
    /// Prompt prefixer obtained from the model-specific loader.
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    /// Device mapper exposed through `MetadataMixin::device_mapper`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// ISQ organization, forwarded to `quantize` when re-applying ISQ.
    organization: IsqOrganization,

    // The fields below are retained from load time so that `re_isq_model` can rebuild the
    // `UqffFullSer` description it passes to `quantize`.
    /// Path of the chat template file, if any.
    template_filename: Option<PathBuf>,
    /// Path of the generation config file, if any.
    generation_config: Option<PathBuf>,
    /// Generation defaults derived from the generation config, if one was found.
    generation_defaults: Option<crate::ModelGenerationDefaults>,
    /// Raw contents of the model config file.
    config: String,
    /// Path of the processor config file, if any.
    processor_filename: Option<PathBuf>,
    /// Path of the preprocessor config file, if any.
    preprocessor_filename: Option<PathBuf>,
    /// Optional imatrix file; used as `ImatrixDataSource::File` when re-applying ISQ.
    imatrix: Option<PathBuf>,
}
90
/// A [`Loader`] for multimodal models, produced by [`MultimodalLoaderBuilder::build`].
pub struct MultimodalLoader {
    /// Model-specific loader selected in `MultimodalLoaderBuilder::build`.
    inner: Box<dyn MultimodalModelLoader>,
    /// Identifier of the model to load.
    model_id: String,
    /// Multimodal-specific loading options.
    config: MultimodalSpecificConfig,
    /// Kind of model being loaded (`Normal`, or `Adapter` after `with_lora`).
    kind: ModelKind,
    /// Optional chat template supplied by the caller.
    chat_template: Option<String>,
    /// Optional tokenizer JSON supplied by the caller.
    tokenizer_json: Option<String>,
    /// Always `None` when constructed through `MultimodalLoaderBuilder::build`.
    xlora_model_id: Option<String>,
    /// Always `None` when constructed through `MultimodalLoaderBuilder::build`.
    xlora_order: Option<Ordering>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// Resolved UQFF file paths, filled in by `load_model_from_hf` when `config.from_uqff` is set.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Optional explicit Jinja chat template supplied by the caller.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory used to initialize the global HF cache.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids supplied through `MultimodalLoaderBuilder::with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
}
108
/// Builder for [`MultimodalLoader`].
///
/// Create one with [`MultimodalLoaderBuilder::new`], optionally refine it with
/// `hf_cache_path` / `with_lora`, and finish with `build`.
#[derive(Default)]
pub struct MultimodalLoaderBuilder {
    /// Identifier of the model to load; `build` requires this to be `Some`.
    model_id: Option<String>,
    /// Multimodal-specific loading options.
    config: MultimodalSpecificConfig,
    /// Kind of model; `new` sets `Normal`, `with_lora` switches it to a LoRA adapter.
    kind: ModelKind,
    /// Optional chat template.
    chat_template: Option<String>,
    /// Optional tokenizer JSON.
    tokenizer_json: Option<String>,
    /// Optional explicit Jinja chat template.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids set by `with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
}
121
/// Loading options specific to multimodal models.
#[derive(Clone, Default)]
pub struct MultimodalSpecificConfig {
    /// Optional topology; used for device mapping, per-layer ISQ overrides and quantization.
    pub topology: Option<Topology>,
    /// If set, a serialization pass is run and this path is passed to `quantize` as the UQFF
    /// output.
    pub write_uqff: Option<PathBuf>,
    /// If set, the model is loaded from these UQFF files instead of being quantized in place.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Forwarded to the model-specific loader's `get_processor`.
    pub max_edge: Option<u32>,
    /// Imatrix file used as the imatrix source for ISQ. Mutually exclusive with
    /// `calibration_file` (loading fails if both are set).
    pub imatrix: Option<PathBuf>,
    /// Text file that is tokenized and run through the model to collect an imatrix.
    /// Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// NOTE(review): presumably a Hugging Face cache directory; the loader code visible here
    /// reads `MultimodalLoader::hf_cache_path` rather than this field — confirm where it is used.
    pub hf_cache_path: Option<PathBuf>,
    /// Path to a Matformer configuration file.
    pub matformer_config_path: Option<PathBuf>,
    /// Name of the Matformer slice to use; if unset while a Matformer config is given, a warning
    /// is logged and no slice is selected.
    pub matformer_slice_name: Option<String>,
    /// ISQ organization (`Default` or `MoeExpertsOnly`); selects the ISQ predicates and the
    /// statistics-tracking mode.
    pub organization: IsqOrganization,
}
136
137impl MultimodalLoaderBuilder {
138 pub fn new(
139 config: MultimodalSpecificConfig,
140 chat_template: Option<String>,
141 tokenizer_json: Option<String>,
142 model_id: Option<String>,
143 jinja_explicit: Option<String>,
144 ) -> Self {
145 Self {
146 config,
147 chat_template,
148 tokenizer_json,
149 model_id,
150 jinja_explicit,
151 kind: ModelKind::Normal,
152 hf_cache_path: None,
153 ..Default::default()
154 }
155 }
156
157 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
158 self.hf_cache_path = Some(hf_cache_path);
159 self
160 }
161
162 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
163 self.kind = ModelKind::Adapter {
164 adapter: AdapterKind::Lora,
165 };
166 self.lora_adapter_ids = Some(lora_adapter_ids);
167 self
168 }
169
170 pub fn build(self, loader: Option<MultimodalLoaderType>) -> Box<dyn Loader> {
171 let loader: Box<dyn MultimodalModelLoader> = match loader {
172 Some(MultimodalLoaderType::Phi3V) => Box::new(Phi3VLoader),
173 Some(MultimodalLoaderType::Idefics2) => Box::new(Idefics2Loader),
174 Some(MultimodalLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
175 Some(MultimodalLoaderType::LLaVA) => Box::new(LLaVALoader),
176 Some(MultimodalLoaderType::VLlama) => Box::new(VLlamaLoader),
177 Some(MultimodalLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
178 Some(MultimodalLoaderType::Idefics3) => Box::new(Idefics3Loader),
179 Some(MultimodalLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
180 Some(MultimodalLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
181 Some(MultimodalLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
182 Some(MultimodalLoaderType::Gemma3) => Box::new(Gemma3Loader),
183 Some(MultimodalLoaderType::Mistral3) => Box::new(Mistral3Loader),
184 Some(MultimodalLoaderType::Llama4) => Box::new(VLlama4Loader),
185 Some(MultimodalLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
186 Some(MultimodalLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
187 Some(MultimodalLoaderType::Qwen3VLMoE) => Box::new(Qwen3VLMoELoader),
188 Some(MultimodalLoaderType::Qwen3_5) => Box::new(Qwen3_5Loader),
189 Some(MultimodalLoaderType::Qwen3_5Moe) => Box::new(Qwen3_5MoeLoader),
190 Some(MultimodalLoaderType::Voxtral) => Box::new(VoxtralLoader),
191 Some(MultimodalLoaderType::Gemma4) => Box::new(Gemma4Loader),
192 None => Box::new(AutoMultimodalLoader),
193 };
194 Box::new(MultimodalLoader {
195 inner: loader,
196 model_id: self.model_id.unwrap(),
197 config: self.config,
198 kind: self.kind,
199 chat_template: self.chat_template,
200 tokenizer_json: self.tokenizer_json,
201 xlora_model_id: None,
202 xlora_order: None,
203 jinja_explicit: self.jinja_explicit,
204 token_source: RwLock::new(None),
205 revision: RwLock::new(None),
206 from_uqff: RwLock::new(None),
207 hf_cache_path: self.hf_cache_path,
208 lora_adapter_ids: self.lora_adapter_ids,
209 })
210 }
211}
212
213impl Loader for MultimodalLoader {
214 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
215 fn load_model_from_hf(
216 &self,
217 revision: Option<String>,
218 token_source: TokenSource,
219 dtype: &dyn TryIntoDType,
220 device: &Device,
221 silent: bool,
222 mapper: DeviceMapSetting,
223 in_situ_quant: Option<IsqType>,
224 paged_attn_config: Option<PagedAttentionConfig>,
225 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
226 let _progress_guard = ProgressScopeGuard::new(silent);
227 let cache = self
228 .hf_cache_path
229 .clone()
230 .map(Cache::new)
231 .unwrap_or_default();
232 GLOBAL_HF_CACHE.get_or_init(|| cache);
233
234 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
235 LocalModelPaths,
236 &token_source,
237 revision.clone(),
238 self,
239 None,
240 None,
241 silent,
242 self.config.from_uqff.is_some()
243 );
244 *self
245 .token_source
246 .write()
247 .expect("Failed to write to token source") = Some(token_source);
248 *self.revision.write().expect("Failed to write to revision") = revision.clone();
249 if let Some(from_uqff) = self.config.from_uqff.clone() {
250 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
251 }
252 self.load_model_from_path(
253 &paths?,
254 dtype,
255 device,
256 silent,
257 mapper,
258 in_situ_quant,
259 paged_attn_config,
260 )
261 }
262
263 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
264 fn load_model_from_path(
265 &self,
266 paths: &Box<dyn ModelPaths>,
267 dtype: &dyn TryIntoDType,
268 device: &Device,
269 silent: bool,
270 mut mapper: DeviceMapSetting,
271 in_situ_quant: Option<IsqType>,
272 mut paged_attn_config: Option<PagedAttentionConfig>,
273 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
274 let _progress_guard = ProgressScopeGuard::new(silent);
275 let config = std::fs::read_to_string(paths.get_config_filename())?;
276
277 if !self.inner.supports_paged_attention(&config) {
278 paged_attn_config = None;
279 }
280
281 debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
282
283 let use_nccl = hanzo_quant::distributed::use_nccl();
284
285 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
286 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
287 let WorkerTransferData::Init { id: _, worker_rank } = payload;
288 vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
291 } else if use_nccl || use_ring() {
292 vec![hanzo_ml::Device::new_cuda(0)?]
293 } else {
294 device_map::get_all_similar_devices(device)?
295 };
296 #[cfg(feature = "cuda")]
297 for device in &available_devices {
298 if let Device::Cuda(dev) = device {
299 unsafe { dev.disable_event_tracking() };
300 }
301 }
302 let device = if use_nccl || use_ring() {
303 available_devices[0].clone()
304 } else {
305 device.clone()
306 };
307
308 let matformer_slicing_config = if let Some(matformer_path) =
310 &self.config.matformer_config_path
311 {
312 use crate::matformer::{MatformerConfig, MatformerSliceConfig};
313 info!("Loading Matformer config from {:?}", matformer_path);
314 let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
315
316 if let Some(slice_name) = &self.config.matformer_slice_name {
317 info!("Using Matformer slice: {}", slice_name);
318 Some(MatformerSliceConfig::new(slice_name.clone(), config))
319 } else {
320 warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
323 None
324 }
325 } else {
326 None
327 };
328
329 let mut max_kv_tokens: Option<usize> = None;
331 if use_nccl || use_ring() {
332 mapper = DeviceMapSetting::DummyNccl {
333 nm_device: available_devices[0].clone(),
334 };
335 } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
336 params = params.maybe_promote_to_multimodal();
338 max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
339
340 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
342
343 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
346 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
347 let weight_pack_factor = {
348 let ser_artifacts =
349 unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(serialized)? };
350 let mut total_pack_factors = 0;
351 let total_tensors = ser_artifacts.tensors().len();
352 for (_, artifact) in ser_artifacts.tensors() {
353 let artifact = artifact.data();
354 let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
356 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
357 {
358 QuantizedSerdeType::Hqq => {
359 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
360 .pack_factor(dtype)
361 }
362 QuantizedSerdeType::Gguf => {
363 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
364 .pack_factor(dtype)
365 }
366 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
367 QuantizedSerdeType::Unquant => 1,
368 QuantizedSerdeType::Afq => {
369 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
370 .pack_factor(dtype)
371 }
372 QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
373 QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
374 };
375 total_pack_factors += pack_factor;
376 }
377
378 total_pack_factors / total_tensors
379 };
380
381 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
382 &config,
383 dtype,
384 weight_pack_factor,
385 matformer_slicing_config.as_ref(),
386 )?;
387 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
388 &config,
389 dtype,
390 weight_pack_factor,
391 matformer_slicing_config.as_ref(),
392 )?;
393 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
394 (
395 layer_sizes_in_bytes,
396 non_mapped_size_in_bytes,
397 layer_sizes_sum + non_mapped_size_in_bytes,
398 )
399 } else if let Some(isq) = in_situ_quant {
400 let weight_pack_factor = isq.pack_factor(dtype);
401 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
402 &config,
403 dtype,
404 weight_pack_factor,
405 matformer_slicing_config.as_ref(),
406 )?;
407 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
408 &config,
409 dtype,
410 weight_pack_factor,
411 matformer_slicing_config.as_ref(),
412 )?;
413 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
414 (
415 layer_sizes_in_bytes,
416 non_mapped_size_in_bytes,
417 layer_sizes_sum + non_mapped_size_in_bytes,
418 )
419 } else {
420 let weight_pack_factor =
422 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
423 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
424 &config,
425 dtype,
426 weight_pack_factor,
427 matformer_slicing_config.as_ref(),
428 )?;
429 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
430 &config,
431 dtype,
432 weight_pack_factor,
433 matformer_slicing_config.as_ref(),
434 )?;
435 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
436 (
437 layer_sizes_in_bytes,
438 non_mapped_size_in_bytes,
439 layer_sizes_sum + non_mapped_size_in_bytes,
440 )
441 };
442
443 let new = auto_device_map::get_device_layers(
444 &*self.inner,
445 &config,
446 self.inner.num_layers(&config)?,
447 layer_sizes_in_bytes,
448 non_mapped_size_in_bytes,
449 total_model_size_in_bytes,
450 &available_devices,
451 dtype,
452 ¶ms,
453 paged_attn_config.as_ref(),
454 )?;
455 mapper = DeviceMapSetting::Map(new);
456 }
457
458 let pipeline_mapper = mapper.into_mapper(
459 self.inner.num_layers(&config)?,
460 &device,
461 self.config.topology.as_ref(),
462 &available_devices,
463 )?;
464 let mapper = mapper.into_mapper(
465 self.inner.num_layers(&config)?,
466 &device,
467 self.config.topology.as_ref(),
468 &available_devices,
469 )?;
470 let mut layer_devices = Vec::new();
471 for layer in 0..self.inner.num_layers(&config)? {
472 let device = mapper.device_for(layer, false).cloned();
473 layer_devices.push(device);
474 }
475 let dtype = mapper.get_min_dtype(dtype)?;
476
477 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
480 if mapping_uses_cpu && paged_attn_config.is_some() {
481 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
482 paged_attn_config = None;
483 }
484
485 trace!("Model config: {:?}", self.inner.get_config_repr(&config)?);
486 if crate::using_flash_attn() {
487 once_log_info("FlashAttention is enabled.");
488 }
489
490 let topology_overrides = self
491 .config
492 .topology
493 .as_ref()
494 .map(|topology| {
495 topology
496 .pattern_overrides()
497 .into_iter()
498 .map(|(regex, layer)| ImmediateIsqOverride {
499 predicate: regex,
500 ty: layer.isq,
501 device: layer.device.clone(),
502 })
503 .collect::<Vec<_>>()
504 })
505 .unwrap_or_default();
506 let has_override_isq = topology_overrides
507 .iter()
508 .any(|override_entry| override_entry.ty.is_some());
509 let topology_requires_post_quant = self
510 .config
511 .topology
512 .as_ref()
513 .is_some_and(|topology| topology.requires_post_quantization());
514
515 let allow_immediate_cli = self.config.imatrix.is_none()
516 && self.config.calibration_file.is_none()
517 && in_situ_quant.is_some();
518
519 let mut immediate_ty = None;
520 let mut immediate_predicates = Vec::new();
521 if allow_immediate_cli {
522 immediate_ty = in_situ_quant;
523 immediate_predicates =
524 if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
525 self.inner.immediate_isq_predicates_moqe(&config)?
526 } else {
527 self.inner.immediate_isq_predicates(&config)?
528 };
529 info!("Applying ISQ to {in_situ_quant:?}");
530 if immediate_predicates.is_empty() {
531 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
532 }
533 }
534
535 let use_immediate = allow_immediate_cli || has_override_isq;
536 if use_immediate {
537 let (pool, num_threads) = hanzo_quant::create_isq_thread_pool(immediate_ty);
538 info!("Applying immediate ISQ in parallel on {num_threads} threads.");
539 hanzo_quant::set_immediate_isq_with_pool(
540 immediate_ty,
541 immediate_predicates.clone(),
542 topology_overrides.clone(),
543 pool,
544 );
545 }
546
547 let mut loading_isq = if use_immediate {
549 false
550 } else {
551 in_situ_quant.is_some()
552 };
553 if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
554 loading_isq = true;
555 }
556 loading_isq |= topology_requires_post_quant;
557 loading_isq |= self.config.from_uqff.is_some();
558
559 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
560 anyhow::bail!(
561 "`imatrix` and `calibration_file` were both specified, this is not allowed."
562 );
563 }
564
565 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
571 loading_isq = false;
572 if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
573 Device::Cpu
574 } else {
575 device.clone()
576 }
577 } else {
578 Device::Cpu
579 };
580
581 let attention_mechanism = if paged_attn_config.is_some() {
582 AttentionImplementation::PagedAttention
583 } else {
584 AttentionImplementation::Eager
585 };
586
587 let multi_progress = Arc::new(new_multi_progress());
588
589 info!(
590 "{}",
591 WeightLoadingMode::from(WeightLoadingState {
592 from_uqff: self.config.from_uqff.is_some(),
593 loading_isq,
594 immediate_isq: use_immediate,
595 write_uqff: self.config.write_uqff.is_some(),
596 })
597 .message("model")
598 );
599
600 let mut model = if use_nccl || use_ring() {
601 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
602 dtype,
603 &device,
604 &available_devices,
605 silent,
606 &config,
607 loading_isq,
608 self.config.from_uqff.is_some(),
609 self.config.organization,
610 &*self.inner,
611 paths.as_ref(),
612 )?;
613
614 match self.kind {
616 ModelKind::Normal => multimodal_normal_model_loader_sharded!(
617 sharded_vb,
618 config,
619 self.inner,
620 mapper,
621 loading_isq,
622 device.clone(),
623 attention_mechanism,
624 multi_progress.clone(),
625 matformer_slicing_config.clone(),
626 ),
627 _ => unreachable!(),
628 }
629 } else {
630 match self.kind {
631 ModelKind::Normal => multimodal_normal_model_loader!(
632 paths,
633 Some(dtype),
634 &load_device,
635 layer_devices.clone(),
636 config,
637 self.inner,
638 silent,
639 mapper,
640 loading_isq,
641 self.config.from_uqff.is_some(),
642 device.clone(),
643 attention_mechanism,
644 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
645 multi_progress,
646 matformer_slicing_config.clone(),
647 ),
648 _ => unreachable!(),
649 }
650 };
651
652 let processor_config_json = paths
653 .get_processor_config()
654 .as_ref()
655 .map(|f| fs::read_to_string(f).unwrap());
656
657 let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
660 {
661 Some(preprocessor_config) => {
662 serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
663 }
664 None => processor_config_json.as_deref().map_or_else(
665 PreProcessorConfig::default,
666 |json| match PreProcessorConfig::from_processor_config_json(json) {
667 Ok(config) => config,
668 Err(err) => {
669 warn!(
670 "Failed to synthesize preprocessor config from processor_config.json: {err}"
671 );
672 PreProcessorConfig::default()
673 }
674 },
675 ),
676 };
677 let processor_config: Option<ProcessorConfig> = processor_config_json
678 .as_deref()
679 .map(|json| serde_json::from_str(json).unwrap());
680
681 let processor = self.inner.get_processor(
682 &config,
683 processor_config,
684 preprocessor_config.clone(),
685 self.config.max_edge,
686 ); let tokenizer = get_tokenizer(
689 paths.get_tokenizer_filename(),
690 Some(processor.get_special_tokens()),
691 )?;
692
693 let gen_conf: Option<GenerationConfig> = paths
694 .get_gen_conf_filename()
695 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
696 let chat_template_explicit = paths
697 .get_chat_template_explicit()
698 .as_ref()
699 .map(|x| x.to_string_lossy().to_string());
700 let mut chat_template = get_chat_template(
701 paths,
702 self.jinja_explicit.as_ref(),
703 chat_template_explicit.as_ref(),
704 self.chat_template.as_ref(),
705 None,
706 );
707
708 if chat_template.chat_template.is_none() {
710 if let Some(default_tmpl) = self.inner.default_chat_template(&config) {
711 info!("Using loader's built-in default chat template.");
712 chat_template.chat_template = Some(ChatTemplateValue(Either::Left(default_tmpl)));
713 }
714 }
715
716 if let Some((bos, eos)) = self.inner.default_bos_eos(&config) {
719 if chat_template.bos_token.is_none() {
720 chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos)));
721 }
722 if chat_template.eos_token.is_none() {
723 chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos)));
724 }
725 }
726
727 if let Some(calibration_file) = &self.config.calibration_file {
728 let calibration_data = std::fs::read_to_string(calibration_file)?;
729 let tokens = tokenizer
731 .encode_fast(calibration_data, false)
732 .map_err(anyhow::Error::msg)?
733 .get_ids()
734 .to_vec();
735 info!(
736 "Collecting imatrix from calibration file `{}` of {} tokens.",
737 calibration_file.display(),
738 tokens.len()
739 );
740 let bos_tok_id = chat_template
741 .bos_tok()
742 .as_deref()
743 .and_then(|tok| tokenizer.token_to_id(tok));
744
745 match self.config.organization {
748 IsqOrganization::Default => model.begin_track_stats()?,
749 IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
750 }
751
752 const CHUNK_SIZE: usize = 1024;
753 let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
754 let start = Instant::now();
755 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
756 let mut chunk = chunk.to_vec();
757 if let Some(bos_tok_id) = bos_tok_id {
758 chunk.insert(0, bos_tok_id);
759 }
760 let chunk_len = chunk.len();
761
762 let start = Instant::now();
763 let inputs = make_prompt_chunk(
764 0,
765 vec![&chunk],
766 &[0],
767 &load_device,
768 None,
769 false,
770 None,
771 None,
772 None,
773 model.config().sliding_window,
774 )?;
775 let _ = model.forward(
776 &inputs.input,
777 None, &inputs.positions,
779 inputs.context_lens,
780 inputs.position_ids,
781 model.default_model_specific_args(&inputs.input),
782 None,
783 &inputs.flash_meta,
784 )?;
785 match model.cache_mut() {
786 EitherCache::Full(full) => {
787 for layer in &mut *full.lock() {
788 *layer = None
789 }
790 }
791 EitherCache::Normal(normal) => {
792 for layer in &mut *normal.lock().unwrap().0 {
793 layer.reset();
794 }
795 }
796 EitherCache::Hybrid(hybrid) => {
797 hybrid.lock().unwrap().reset();
798 }
799 }
800 let end = Instant::now();
801 info!(
802 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
803 i + 1,
804 end.duration_since(start).as_secs_f32()
805 );
806 }
807 load_device.synchronize()?;
808 let end = Instant::now();
809 info!(
810 "Finished collecting imatrix in {:.2}s",
811 end.duration_since(start).as_secs_f32()
812 );
813 }
814
815 let should_serialize = self.config.write_uqff.is_some();
816 let should_quantize_pass = loading_isq;
817
818 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
819 let imatrix_source = if should_quantize_pass {
820 match (
821 self.config.imatrix.as_ref(),
822 self.config.calibration_file.is_some(),
823 ) {
824 (None, false) => None,
825 (Some(file), false) => Some(ImatrixDataSource::File(file)),
826 (None, true) => Some(ImatrixDataSource::Collected),
827 (Some(_), true) => unreachable!(),
828 }
829 } else {
830 None
831 };
832 if should_quantize_pass {
833 debug!("Applying ISQ to all ranks.");
834 } else {
835 debug!("Serializing existing ISQ tensors without additional quantization.");
836 }
837 model.quantize(
838 in_situ_quant,
839 device.clone(),
840 self.config.topology.as_ref(),
841 silent,
842 imatrix_source,
843 self.config.organization,
844 should_quantize_pass,
845 self.config.write_uqff.as_ref(),
846 UqffFullSer {
847 tokenizer: &tokenizer,
848 template_filename: paths.get_template_filename(),
849 generation_config: paths.get_gen_conf_filename(),
850 config: config.clone(),
851 processor_filename: paths.get_processor_config(),
852 preprocessor_filename: paths.get_preprocessor_config(),
853 modules: None,
854 module_paths: None,
855 },
856 Arc::new(new_multi_progress()),
857 )?;
858 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
859 model.load_from_artifacts(
860 device.clone(),
861 self.config.topology.as_ref(),
862 silent,
863 from_uqff,
864 )?;
865 }
866
867 let model_metadata = model.model_config();
868 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
869 anyhow::ensure!(
870 !matches!(self.kind, ModelKind::Adapter { .. }),
871 "PagedAttention does not support adapter models."
872 );
873 let cache_config = calculate_cache_config(
874 paged_attn_config.mem_gpu,
875 paged_attn_config.block_size,
876 dtype,
877 paged_attn_config.cache_type,
878 model_metadata.as_ref(),
879 &device,
880 &layer_devices,
881 silent,
882 None,
883 max_kv_tokens,
884 )?;
885 let cache_engine = CacheEngine::new(
886 model_metadata.as_ref(),
887 &cache_config,
888 dtype,
889 &device,
890 layer_devices,
891 )?;
892 (Some(cache_config), Some(cache_engine))
893 } else {
894 (None, None)
895 };
896
897 let max_seq_len = model.max_seq_len();
898 let llg_factory = build_llg_factory(tokenizer.clone())?;
899 let num_hidden_layers = match model.cache() {
900 EitherCache::Full(full) => full.lock().len(),
901 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
902 EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
903 };
904 let generation_defaults = gen_conf
905 .as_ref()
906 .and_then(GenerationConfig::generation_defaults);
907 let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
908 let sliding_window = model.config().sliding_window;
909 Ok(Arc::new(Mutex::new(MultimodalPipeline {
910 model,
911 tokenizer: tokenizer.into(),
912 chat_template: Arc::new(chat_template),
913 model_id: self.model_id.clone(),
914 metadata: Arc::new(GeneralMetadata {
915 max_seq_len,
916 llg_factory: Some(llg_factory),
917 is_xlora: false,
918 num_hidden_layers,
919 eos_tok: eos,
920 kind: self.kind.clone(),
921 no_kv_cache: false,
922 no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
923 activation_dtype: dtype,
924 sliding_window,
925 cache_config,
926 cache_engine,
927 model_metadata: Some(model_metadata),
928 modalities: self.inner.modalities(&config)?,
929 }),
930 processor,
931 prefixer: self.inner.prefixer(&config),
932 preprocessor_config: Arc::new(preprocessor_config),
933 topology: self.config.topology.clone(),
934 silent,
935 organization: self.config.organization,
936 template_filename: paths.get_template_filename().clone(),
937 generation_config: paths.get_gen_conf_filename().cloned(),
938 generation_defaults,
939 config,
940 processor_filename: paths.get_processor_config().clone(),
941 preprocessor_filename: paths.get_preprocessor_config().clone(),
942 mapper: pipeline_mapper,
943 imatrix: self.config.imatrix.clone(),
944 })))
945 }
946
947 fn get_id(&self) -> String {
948 self.model_id.to_string()
949 }
950
951 fn get_kind(&self) -> ModelKind {
952 self.kind.clone()
953 }
954}
955
956impl PreProcessingMixin for MultimodalPipeline {
957 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
958 Some(self.chat_template.clone())
959 }
960 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
961 Some(self.preprocessor_config.clone())
962 }
963 fn get_processor(&self) -> Arc<dyn super::Processor> {
964 self.processor.clone()
965 }
966}
967
968impl IsqPipelineMixin for MultimodalPipeline {
969 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
970 let device = self.device().clone();
971 self.model
972 .quantize(
973 Some(dtype),
974 device,
975 self.topology.as_ref(),
976 self.silent,
977 self.imatrix.as_ref().map(ImatrixDataSource::File),
978 self.organization,
979 true,
980 None,
981 UqffFullSer {
982 tokenizer: &self.tokenizer,
983 template_filename: &self.template_filename,
984 generation_config: self.generation_config.as_ref(),
985 config: self.config.clone(),
986 processor_filename: &self.processor_filename,
987 preprocessor_filename: &self.preprocessor_filename,
988 modules: None,
989 module_paths: None,
990 },
991 Arc::new(new_multi_progress()),
992 )
993 .map_err(anyhow::Error::msg)
994 }
995}
996
997impl CacheManagerMixin for MultimodalPipeline {
998 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
999 match self.model.cache() {
1000 EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
1001 EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
1002 EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
1003 }
1004 }
1005 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1006 match self.model.cache() {
1007 EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
1008 EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
1009 EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
1010 }
1011 }
1012 fn set_none_cache(
1013 &self,
1014 seqs: &mut [&mut Sequence],
1015 reset_non_granular: bool,
1016 modify_draft_cache: bool,
1017
1018 load_preallocated_cache: bool,
1019 ) {
1020 match self.model.cache() {
1021 EitherCache::Full(_) => {
1022 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
1023 }
1024 EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
1025 self,
1026 seqs,
1027 modify_draft_cache,
1028 load_preallocated_cache,
1029 ),
1030 EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
1031 self,
1032 seqs,
1033 modify_draft_cache,
1034 load_preallocated_cache,
1035 ),
1036 }
1037 self.model.reset_model_specific_state();
1041
1042 if reset_non_granular {
1043 self.reset_non_granular_state()
1044 }
1045 }
1046 fn cache(&self) -> &EitherCache {
1047 self.model.cache()
1048 }
1049}
1050
1051impl MetadataMixin for MultimodalPipeline {
1052 fn device(&self) -> Device {
1053 self.model.device().clone()
1054 }
1055 fn get_metadata(&self) -> Arc<GeneralMetadata> {
1056 self.metadata.clone()
1057 }
1058 fn name(&self) -> String {
1059 self.model_id.clone()
1060 }
1061 fn reset_non_granular_state(&self) {
1062 self.model.reset_model_specific_state();
1063 }
1064 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1065 Some(self.tokenizer.clone())
1066 }
1067 fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
1068 self.generation_defaults.clone()
1069 }
1070 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1071 Some(&*self.mapper)
1072 }
1073}
1074
1075impl crate::speculative::driver::SpeculativePipelineExt for MultimodalPipeline {
1076 fn has_speculative_proposer(&self) -> bool {
1077 self.model.has_speculative_proposer()
1078 }
1079
1080 fn speculative_proposal_len(&self) -> Option<usize> {
1081 self.model.speculative_proposal_len()
1082 }
1083
1084 fn speculative_target_hiddens(
1085 &self,
1086 rows: &[(usize, usize)],
1087 ) -> hanzo_ml::Result<Option<Tensor>> {
1088 self.model.speculative_target_hiddens(rows)
1089 }
1090
1091 fn speculative_propose(
1092 &mut self,
1093 ctx: crate::speculative::SpeculativeProposeBatchCtx<'_>,
1094 ) -> hanzo_ml::Result<Option<crate::speculative::SpeculativeProposalBatch>> {
1095 self.model.speculative_propose(ctx)
1096 }
1097
1098 fn build_speculative_verify_inputs(
1099 &self,
1100 input_meta: InputMetadata,
1101 ) -> hanzo_ml::Result<Box<dyn Any>> {
1102 let model_specific_args = self.model.default_model_specific_args(&input_meta.input);
1103 Ok(Box::new(ModelInputs {
1104 input_ids: input_meta.input,
1105 seqlen_offsets: input_meta.positions,
1106 context_lens: input_meta.context_lens,
1107 position_ids: input_meta.position_ids,
1108 pixel_values: None,
1109 model_specific_args,
1110 paged_attn_meta: input_meta.paged_attn_meta,
1111 flash_meta: input_meta.flash_meta,
1112 }))
1113 }
1114}
1115
1116#[async_trait::async_trait]
1117impl Pipeline for MultimodalPipeline {
1118 fn forward_inputs(
1119 &mut self,
1120 inputs: Box<dyn Any>,
1121 return_raw_logits: bool,
1122 ) -> hanzo_ml::Result<ForwardInputsResult> {
1123 let ModelInputs {
1124 input_ids,
1125 seqlen_offsets,
1126 context_lens,
1127 position_ids,
1128 pixel_values,
1129 model_specific_args,
1130 paged_attn_meta,
1131 flash_meta,
1132 } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
1133 let metadata = self.get_metadata();
1134 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1135 (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
1136 (Some(_), None) => {
1137 hanzo_ml::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1139 }
1140 (None, Some(_)) => {
1141 hanzo_ml::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1143 }
1144 (None, None) => None,
1145 };
1146 let logits = self.model.forward(
1147 &input_ids,
1148 pixel_values,
1149 &seqlen_offsets,
1150 context_lens,
1151 position_ids,
1152 model_specific_args,
1153 paged_attn_meta,
1154 &flash_meta,
1155 )?;
1156 if return_raw_logits {
1157 Ok(ForwardInputsResult::RawLogits { logits })
1158 } else {
1159 Ok(ForwardInputsResult::CausalGeneration { logits })
1160 }
1161 }
1162
1163 fn attach_speculative(
1164 &mut self,
1165 config: crate::speculative::SpeculativeConfig,
1166 ) -> hanzo_ml::Result<()> {
1167 if matches!(config, crate::speculative::SpeculativeConfig::Mtp(_))
1168 && self.get_metadata().cache_engine.is_none()
1169 {
1170 hanzo_ml::bail!(
1171 "MTP speculative decoding currently requires PagedAttention for this pipeline."
1172 );
1173 }
1174 if let Some(info) = self.model.attach_speculative(config)? {
1175 self.model.log_speculative_attach(&info);
1176 }
1177 Ok(())
1178 }
1179
1180 #[allow(clippy::too_many_arguments)]
1181 async fn try_sample_speculative_causal_gen(
1182 &mut self,
1183 seqs: &mut [&mut Sequence],
1184 logits: &[Tensor],
1185 prefix_cacher: &mut PrefixCacheManagerV2,
1186 disable_eos_stop: bool,
1187 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1188 metadata: Option<crate::pipeline::text_models_inputs_processor::PagedAttentionMeta>,
1189 ) -> hanzo_ml::Result<bool> {
1190 if !self.model.has_speculative_proposer() {
1191 crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1192 return Ok(false);
1193 }
1194
1195 let general_metadata = self.get_metadata();
1196 if let Some(cache_engine) = general_metadata.cache_engine.as_ref() {
1197 let Some(metadata) = metadata else {
1198 crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1199 return Ok(false);
1200 };
1201 let cache = crate::speculative::cache::PagedSpeculativeCacheAccess::new(
1202 &metadata,
1203 cache_engine,
1204 );
1205 return crate::speculative::driver::try_sample_speculative_causal_gen(
1206 self,
1207 seqs,
1208 logits,
1209 prefix_cacher,
1210 disable_eos_stop,
1211 rng,
1212 &cache,
1213 )
1214 .await;
1215 }
1216
1217 crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1218 Ok(false)
1219 }
1220
1221 async fn sample_causal_gen(
1222 &self,
1223 seqs: &mut [&mut Sequence],
1224 logits: Vec<Tensor>,
1225 prefix_cacher: &mut PrefixCacheManagerV2,
1226 disable_eos_stop: bool,
1227 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1228 ) -> Result<(), hanzo_ml::Error> {
1229 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1230 }
1231 fn category(&self) -> ModelCategory {
1232 ModelCategory::Multimodal {
1233 prefixer: self.prefixer.clone(),
1234 }
1235 }
1236
1237 fn encoder_cache_counters(
1238 &self,
1239 ) -> Option<(
1240 std::sync::Arc<std::sync::atomic::AtomicUsize>,
1241 std::sync::Arc<std::sync::atomic::AtomicUsize>,
1242 )> {
1243 self.model.encoder_cache_counters()
1244 }
1245}
1246
1247impl AnyMoePipelineMixin for MultimodalPipeline {
1248 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> hanzo_ml::Result<()> {
1249 self.model.finish_training(gate_model_id)
1250 }
1251 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1252 self.model.get_vars()
1253 }
1254 fn amoe_base_model_trainable_params(&self) -> usize {
1255 self.model.trainable_params()
1256 }
1257 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1258 self.model.take_cached_gating_outputs()
1259 }
1260 fn amoe_create_layers(
1261 &mut self,
1262 model_ids: Vec<String>,
1263 token: &TokenSource,
1264 revision: Option<String>,
1265 match_regex: &str,
1266 config: crate::amoe::AnyMoeConfig,
1267 dtype: hanzo_ml::DType,
1268 dev: &Device,
1269 (prefix, mlp): (String, String),
1270 layers: Vec<usize>,
1271 expert_type: AnyMoeExpertType,
1272 silent: bool,
1273 gate_model_id: Option<String>,
1274 ) -> hanzo_ml::Result<()> {
1275 let mut vbs = Vec::new();
1276 let regex = Regex::new(match_regex).map_err(hanzo_ml::Error::msg)?;
1278 for model_id in model_ids {
1279 let model_id_str = &model_id;
1280 let model_id = Path::new(&model_id);
1281
1282 let api = {
1283 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1284 let mut api = ApiBuilder::from_cache(cache)
1285 .with_progress(!silent)
1286 .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1287 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1288 api = api.with_cache_dir(cache_dir);
1289 }
1290 api.build().map_err(hanzo_ml::Error::msg)?
1291 };
1292 let revision = revision.clone().unwrap_or("main".to_string());
1293 let api = api.repo(Repo::with_revision(
1294 model_id_str.clone(),
1295 RepoType::Model,
1296 revision.clone(),
1297 ));
1298
1299 let mut filenames = vec![];
1300 for rfilename in api_dir_list!(api, model_id, true, &revision)
1301 .filter(|x| x.ends_with(".safetensors"))
1302 {
1303 filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1304 }
1305
1306 let regex = regex.clone();
1307 let match_regex_clone = match_regex.to_string();
1308 let layers_clone = layers.clone();
1309 let vb = from_mmaped_safetensors(
1310 filenames,
1311 vec![],
1312 Some(dtype),
1313 dev,
1314 vec![None],
1315 silent,
1316 None,
1317 move |key| {
1318 if regex.is_match(&key) {
1319 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1322 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1323 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1324 .parse::<usize>()
1325 .unwrap();
1326 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1327 } else {
1328 false
1329 }
1330 },
1331 Arc::new(|_| DeviceForLoadTensor::Base),
1332 )?;
1333 vbs.push(vb);
1334 }
1335
1336 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1337 let model_id_str = &gate_model_id;
1338 let model_id = Path::new(&gate_model_id);
1339
1340 let api = {
1341 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1342 let mut api = ApiBuilder::from_cache(cache)
1343 .with_progress(!silent)
1344 .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1345 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1346 api = api.with_cache_dir(cache_dir);
1347 }
1348 api.build().map_err(hanzo_ml::Error::msg)?
1349 };
1350 let revision = revision.clone().unwrap_or("main".to_string());
1351 let api = api.repo(Repo::with_revision(
1352 model_id_str.clone(),
1353 RepoType::Model,
1354 revision.clone(),
1355 ));
1356
1357 let mut gate_filenames = vec![];
1358 for rfilename in api_dir_list!(api, model_id, true, &revision)
1359 .filter(|x| x.ends_with(".safetensors"))
1360 {
1361 gate_filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1362 }
1363 assert_eq!(
1364 gate_filenames.len(),
1365 1,
1366 "Gate model ID must contain only one .safetensors file"
1367 );
1368
1369 let vb = from_mmaped_safetensors(
1370 gate_filenames.clone(),
1371 vec![],
1372 Some(dtype),
1373 dev,
1374 vec![None],
1375 silent,
1376 None,
1377 |_| true,
1378 Arc::new(|_| DeviceForLoadTensor::Base),
1379 )?;
1380 info!(
1381 "Loaded gating layers from `{}`",
1382 gate_filenames[0].display()
1383 );
1384 Some(vb)
1385 } else {
1386 None
1387 };
1388
1389 self.model
1390 .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
1391 }
1392 fn amoe_supported(&self) -> bool {
1393 self.model.amoe_supported()
1394 }
1395}