1use super::isq::ImatrixDataSource;
2use super::isq::UqffFullSer;
3use super::{
4 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoMultimodalLoader,
5 CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
6 GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
7 ModelKind, ModelPaths, MultimodalModel, MultimodalModelLoader, MultimodalPromptPrefixer,
8 Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, Qwen3VLLoader, Qwen3VLMoELoader,
9 Qwen3_5Loader, Qwen3_5MoeLoader, TokenSource, VLlama4Loader, VLlamaLoader,
10};
11use super::{
12 Gemma3nLoader, Gemma4Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
13 Mistral3Loader, MultimodalLoaderType, Phi3VLoader, Qwen2_5VLLoader, VoxtralLoader,
14};
15use crate::attention::ATTENTION_CHUNK_SIZE;
16use crate::device_map::{self, DeviceMapper};
17use crate::distributed::{self, use_ring, WorkerTransferData};
18use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
19use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
20use crate::pipeline::chat_template::{
21 calculate_eos_tokens, BeginEndUnkPadTok, ChatTemplateValue, GenerationConfig,
22};
23use crate::pipeline::llg::build_llg_factory;
24use crate::pipeline::loaders::auto_device_map;
25use crate::pipeline::loaders::QuantizationConfigShim;
26use crate::pipeline::sampling::sample_and_add_toks;
27use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
28use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
29use crate::prefix_cacher::PrefixCacheManagerV2;
30use crate::sequence::Sequence;
31use crate::utils::tokenizer::get_tokenizer;
32use crate::utils::varbuilder_utils::DeviceForLoadTensor;
33use crate::utils::{
34 progress::{new_multi_progress, ProgressScopeGuard},
35 tokens::get_token,
36 varbuilder_utils::from_mmaped_safetensors,
37};
38use crate::vision_models::preprocessor_config::PreProcessorConfig;
39use crate::vision_models::processor_config::ProcessorConfig;
40use crate::vision_models::ModelInputs;
41use crate::{
42 api_dir_list, api_get_file, get_paths, get_uqff_paths, multimodal_normal_model_loader,
43 multimodal_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
44 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
45};
46use anyhow::Result;
47use candle_core::{Device, Tensor, Var};
48use either::Either;
49use hf_hub::Cache;
50use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
51use mistralrs_quant::log::once_log_info;
52use mistralrs_quant::{
53 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
54};
55use rand_isaac::Isaac64Rng;
56use regex_automata::meta::Regex;
57use std::any::Any;
58use std::borrow::Cow;
59use std::path::{Path, PathBuf};
60use std::str::FromStr;
61use std::sync::{Arc, RwLock};
62use std::time::Instant;
63use std::{env, fs};
64use tokenizers::Tokenizer;
65use tokio::sync::Mutex;
66use tracing::{info, warn};
67
/// A loaded multimodal model together with everything needed to preprocess inputs,
/// tokenize text and, if requested, re-quantize the model in place.
pub struct MultimodalPipeline {
    /// The loaded model; all forward passes and cache access go through it.
    model: Box<dyn MultimodalModel + Send + Sync>,
    /// Tokenizer exposed through `MetadataMixin::tokenizer`.
    tokenizer: Arc<Tokenizer>,
    /// Chat template exposed through `PreProcessingMixin::get_chat_template`.
    chat_template: Arc<ChatTemplate>,
    /// Model identifier, returned by `MetadataMixin::name`.
    model_id: String,
    /// General metadata (EOS tokens, cache config, modalities, ...), returned by
    /// `MetadataMixin::get_metadata`.
    metadata: Arc<GeneralMetadata>,
    /// Model-specific input processor, returned by `PreProcessingMixin::get_processor`.
    processor: Arc<dyn Processor + Send + Sync>,
    /// Preprocessor configuration, returned type-erased by
    /// `PreProcessingMixin::get_input_processor_config`.
    preprocessor_config: Arc<PreProcessorConfig>,
    /// Optional topology, reused when re-quantizing (`re_isq_model`).
    topology: Option<Topology>,
    /// Passed through as the `silent` flag when re-quantizing.
    silent: bool,
    /// Prompt prefixer obtained from the model loader (`MultimodalModelLoader::prefixer`).
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    /// Device mapper exposed through `MetadataMixin::device_mapper`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// ISQ organization reused when re-quantizing.
    organization: IsqOrganization,

    // The fields below are only used to fill `UqffFullSer` when re-quantizing.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    /// Generation defaults derived from the generation config, if one was present.
    generation_defaults: Option<crate::ModelGenerationDefaults>,
    /// Raw model config JSON text.
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    /// Imatrix file; used as `ImatrixDataSource::File` when re-quantizing.
    imatrix: Option<PathBuf>,
}
91
92pub struct MultimodalLoader {
94 inner: Box<dyn MultimodalModelLoader>,
95 model_id: String,
96 config: MultimodalSpecificConfig,
97 kind: ModelKind,
98 chat_template: Option<String>,
99 tokenizer_json: Option<String>,
100 xlora_model_id: Option<String>,
101 xlora_order: Option<Ordering>,
102 token_source: RwLock<Option<TokenSource>>,
103 revision: RwLock<Option<String>>,
104 from_uqff: RwLock<Option<Vec<PathBuf>>>,
105 jinja_explicit: Option<String>,
106 hf_cache_path: Option<PathBuf>,
107 lora_adapter_ids: Option<Vec<String>>,
108}
109
110#[derive(Default)]
111pub struct MultimodalLoaderBuilder {
113 model_id: Option<String>,
114 config: MultimodalSpecificConfig,
115 kind: ModelKind,
116 chat_template: Option<String>,
117 tokenizer_json: Option<String>,
118 jinja_explicit: Option<String>,
119 hf_cache_path: Option<PathBuf>,
120 lora_adapter_ids: Option<Vec<String>>,
121}
122
123#[derive(Clone, Default)]
124pub struct MultimodalSpecificConfig {
126 pub topology: Option<Topology>,
127 pub write_uqff: Option<PathBuf>,
128 pub from_uqff: Option<Vec<PathBuf>>,
129 pub max_edge: Option<u32>,
130 pub imatrix: Option<PathBuf>,
131 pub calibration_file: Option<PathBuf>,
132 pub hf_cache_path: Option<PathBuf>,
133 pub matformer_config_path: Option<PathBuf>,
134 pub matformer_slice_name: Option<String>,
135 pub organization: IsqOrganization,
136}
137
138impl MultimodalLoaderBuilder {
139 pub fn new(
140 config: MultimodalSpecificConfig,
141 chat_template: Option<String>,
142 tokenizer_json: Option<String>,
143 model_id: Option<String>,
144 jinja_explicit: Option<String>,
145 ) -> Self {
146 Self {
147 config,
148 chat_template,
149 tokenizer_json,
150 model_id,
151 jinja_explicit,
152 kind: ModelKind::Normal,
153 hf_cache_path: None,
154 ..Default::default()
155 }
156 }
157
158 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
159 self.hf_cache_path = Some(hf_cache_path);
160 self
161 }
162
163 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
164 self.kind = ModelKind::Adapter {
165 adapter: AdapterKind::Lora,
166 };
167 self.lora_adapter_ids = Some(lora_adapter_ids);
168 self
169 }
170
171 pub fn build(self, loader: Option<MultimodalLoaderType>) -> Box<dyn Loader> {
172 let loader: Box<dyn MultimodalModelLoader> = match loader {
173 Some(MultimodalLoaderType::Phi3V) => Box::new(Phi3VLoader),
174 Some(MultimodalLoaderType::Idefics2) => Box::new(Idefics2Loader),
175 Some(MultimodalLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
176 Some(MultimodalLoaderType::LLaVA) => Box::new(LLaVALoader),
177 Some(MultimodalLoaderType::VLlama) => Box::new(VLlamaLoader),
178 Some(MultimodalLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
179 Some(MultimodalLoaderType::Idefics3) => Box::new(Idefics3Loader),
180 Some(MultimodalLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
181 Some(MultimodalLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
182 Some(MultimodalLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
183 Some(MultimodalLoaderType::Gemma3) => Box::new(Gemma3Loader),
184 Some(MultimodalLoaderType::Mistral3) => Box::new(Mistral3Loader),
185 Some(MultimodalLoaderType::Llama4) => Box::new(VLlama4Loader),
186 Some(MultimodalLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
187 Some(MultimodalLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
188 Some(MultimodalLoaderType::Qwen3VLMoE) => Box::new(Qwen3VLMoELoader),
189 Some(MultimodalLoaderType::Qwen3_5) => Box::new(Qwen3_5Loader),
190 Some(MultimodalLoaderType::Qwen3_5Moe) => Box::new(Qwen3_5MoeLoader),
191 Some(MultimodalLoaderType::Voxtral) => Box::new(VoxtralLoader),
192 Some(MultimodalLoaderType::Gemma4) => Box::new(Gemma4Loader),
193 None => Box::new(AutoMultimodalLoader),
194 };
195 Box::new(MultimodalLoader {
196 inner: loader,
197 model_id: self.model_id.unwrap(),
198 config: self.config,
199 kind: self.kind,
200 chat_template: self.chat_template,
201 tokenizer_json: self.tokenizer_json,
202 xlora_model_id: None,
203 xlora_order: None,
204 jinja_explicit: self.jinja_explicit,
205 token_source: RwLock::new(None),
206 revision: RwLock::new(None),
207 from_uqff: RwLock::new(None),
208 hf_cache_path: self.hf_cache_path,
209 lora_adapter_ids: self.lora_adapter_ids,
210 })
211 }
212}
213
214impl Loader for MultimodalLoader {
215 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
216 fn load_model_from_hf(
217 &self,
218 revision: Option<String>,
219 token_source: TokenSource,
220 dtype: &dyn TryIntoDType,
221 device: &Device,
222 silent: bool,
223 mapper: DeviceMapSetting,
224 in_situ_quant: Option<IsqType>,
225 paged_attn_config: Option<PagedAttentionConfig>,
226 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
227 let _progress_guard = ProgressScopeGuard::new(silent);
228 let cache = self
229 .hf_cache_path
230 .clone()
231 .map(Cache::new)
232 .unwrap_or_default();
233 GLOBAL_HF_CACHE.get_or_init(|| cache);
234
235 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
236 LocalModelPaths,
237 &token_source,
238 revision.clone(),
239 self,
240 None,
241 None,
242 silent,
243 self.config.from_uqff.is_some()
244 );
245 *self
246 .token_source
247 .write()
248 .expect("Failed to write to token source") = Some(token_source);
249 *self.revision.write().expect("Failed to write to revision") = revision.clone();
250 if let Some(from_uqff) = self.config.from_uqff.clone() {
251 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
252 }
253 self.load_model_from_path(
254 &paths?,
255 dtype,
256 device,
257 silent,
258 mapper,
259 in_situ_quant,
260 paged_attn_config,
261 )
262 }
263
264 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
265 fn load_model_from_path(
266 &self,
267 paths: &Box<dyn ModelPaths>,
268 dtype: &dyn TryIntoDType,
269 device: &Device,
270 silent: bool,
271 mut mapper: DeviceMapSetting,
272 in_situ_quant: Option<IsqType>,
273 mut paged_attn_config: Option<PagedAttentionConfig>,
274 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
275 let _progress_guard = ProgressScopeGuard::new(silent);
276 let config = std::fs::read_to_string(paths.get_config_filename())?;
277
278 if !self.inner.supports_paged_attention(&config) {
279 paged_attn_config = None;
280 }
281
282 info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
283
284 let use_nccl = mistralrs_quant::distributed::use_nccl();
285
286 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
287 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
288 let WorkerTransferData::Init { id: _, worker_rank } = payload;
289 vec![candle_core::Device::new_cuda(worker_rank + 1)?]
292 } else if use_nccl || use_ring() {
293 vec![candle_core::Device::new_cuda(0)?]
294 } else {
295 device_map::get_all_similar_devices(device)?
296 };
297 #[cfg(feature = "cuda")]
298 for device in &available_devices {
299 if let Device::Cuda(dev) = device {
300 unsafe { dev.disable_event_tracking() };
301 }
302 }
303 let device = if use_nccl || use_ring() {
304 available_devices[0].clone()
305 } else {
306 device.clone()
307 };
308
309 let matformer_slicing_config = if let Some(matformer_path) =
311 &self.config.matformer_config_path
312 {
313 use crate::matformer::{MatformerConfig, MatformerSliceConfig};
314 info!("Loading Matformer config from {:?}", matformer_path);
315 let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
316
317 if let Some(slice_name) = &self.config.matformer_slice_name {
318 info!("Using Matformer slice: {}", slice_name);
319 Some(MatformerSliceConfig::new(slice_name.clone(), config))
320 } else {
321 warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
324 None
325 }
326 } else {
327 None
328 };
329
330 let mut max_kv_tokens: Option<usize> = None;
332 if use_nccl || use_ring() {
333 mapper = DeviceMapSetting::DummyNccl {
334 nm_device: available_devices[0].clone(),
335 };
336 } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
337 params = params.maybe_promote_to_multimodal();
339 max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
340
341 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
343
344 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
347 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
348 let weight_pack_factor = {
349 let ser_artifacts = unsafe {
350 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
351 };
352 let mut total_pack_factors = 0;
353 let total_tensors = ser_artifacts.tensors().len();
354 for (_, artifact) in ser_artifacts.tensors() {
355 let artifact = artifact.data();
356 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
358 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
359 {
360 QuantizedSerdeType::Hqq => {
361 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
362 .pack_factor(dtype)
363 }
364 QuantizedSerdeType::Gguf => {
365 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
366 .pack_factor(dtype)
367 }
368 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
369 QuantizedSerdeType::Unquant => 1,
370 QuantizedSerdeType::Afq => {
371 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
372 .pack_factor(dtype)
373 }
374 QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
375 QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
376 };
377 total_pack_factors += pack_factor;
378 }
379
380 total_pack_factors / total_tensors
381 };
382
383 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
384 &config,
385 dtype,
386 weight_pack_factor,
387 matformer_slicing_config.as_ref(),
388 )?;
389 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
390 &config,
391 dtype,
392 weight_pack_factor,
393 matformer_slicing_config.as_ref(),
394 )?;
395 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
396 (
397 layer_sizes_in_bytes,
398 non_mapped_size_in_bytes,
399 layer_sizes_sum + non_mapped_size_in_bytes,
400 )
401 } else if let Some(isq) = in_situ_quant {
402 let weight_pack_factor = isq.pack_factor(dtype);
403 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
404 &config,
405 dtype,
406 weight_pack_factor,
407 matformer_slicing_config.as_ref(),
408 )?;
409 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
410 &config,
411 dtype,
412 weight_pack_factor,
413 matformer_slicing_config.as_ref(),
414 )?;
415 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
416 (
417 layer_sizes_in_bytes,
418 non_mapped_size_in_bytes,
419 layer_sizes_sum + non_mapped_size_in_bytes,
420 )
421 } else {
422 let weight_pack_factor =
424 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
425 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
426 &config,
427 dtype,
428 weight_pack_factor,
429 matformer_slicing_config.as_ref(),
430 )?;
431 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
432 &config,
433 dtype,
434 weight_pack_factor,
435 matformer_slicing_config.as_ref(),
436 )?;
437 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
438 (
439 layer_sizes_in_bytes,
440 non_mapped_size_in_bytes,
441 layer_sizes_sum + non_mapped_size_in_bytes,
442 )
443 };
444
445 let new = auto_device_map::get_device_layers(
446 &*self.inner,
447 &config,
448 self.inner.num_layers(&config)?,
449 layer_sizes_in_bytes,
450 non_mapped_size_in_bytes,
451 total_model_size_in_bytes,
452 &available_devices,
453 dtype,
454 ¶ms,
455 paged_attn_config.as_ref(),
456 )?;
457 mapper = DeviceMapSetting::Map(new);
458 }
459
460 let pipeline_mapper = mapper.into_mapper(
461 self.inner.num_layers(&config)?,
462 &device,
463 self.config.topology.as_ref(),
464 &available_devices,
465 )?;
466 let mapper = mapper.into_mapper(
467 self.inner.num_layers(&config)?,
468 &device,
469 self.config.topology.as_ref(),
470 &available_devices,
471 )?;
472 let mut layer_devices = Vec::new();
473 for layer in 0..self.inner.num_layers(&config)? {
474 let device = mapper.device_for(layer, false).cloned();
475 layer_devices.push(device);
476 }
477 let dtype = mapper.get_min_dtype(dtype)?;
478
479 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
482 if mapping_uses_cpu && paged_attn_config.is_some() {
483 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
484 paged_attn_config = None;
485 }
486
487 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
488 if crate::using_flash_attn() {
489 once_log_info("FlashAttention is enabled.");
490 }
491
492 let topology_overrides = self
493 .config
494 .topology
495 .as_ref()
496 .map(|topology| {
497 topology
498 .pattern_overrides()
499 .into_iter()
500 .map(|(regex, layer)| ImmediateIsqOverride {
501 predicate: regex,
502 ty: layer.isq,
503 device: layer.device.clone(),
504 })
505 .collect::<Vec<_>>()
506 })
507 .unwrap_or_default();
508 let has_override_isq = topology_overrides
509 .iter()
510 .any(|override_entry| override_entry.ty.is_some());
511 let topology_requires_post_quant = self
512 .config
513 .topology
514 .as_ref()
515 .is_some_and(|topology| topology.requires_post_quantization());
516
517 let allow_immediate_cli = self.config.imatrix.is_none()
518 && self.config.calibration_file.is_none()
519 && in_situ_quant.is_some();
520
521 let mut immediate_ty = None;
522 let mut immediate_predicates = Vec::new();
523 if allow_immediate_cli {
524 immediate_ty = in_situ_quant;
525 immediate_predicates =
526 if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
527 self.inner.immediate_isq_predicates_moqe(&config)?
528 } else {
529 self.inner.immediate_isq_predicates(&config)?
530 };
531 info!("Applying ISQ to {in_situ_quant:?}");
532 if immediate_predicates.is_empty() {
533 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
534 }
535 }
536
537 let use_immediate = allow_immediate_cli || has_override_isq;
538 if use_immediate {
539 let (pool, num_threads) = mistralrs_quant::create_isq_thread_pool(immediate_ty);
540 info!("Applying immediate ISQ in parallel on {num_threads} threads.");
541 mistralrs_quant::set_immediate_isq_with_pool(
542 immediate_ty,
543 immediate_predicates.clone(),
544 topology_overrides.clone(),
545 pool,
546 );
547 }
548
549 let mut loading_isq = if use_immediate {
551 false
552 } else {
553 in_situ_quant.is_some()
554 };
555 if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
556 loading_isq = true;
557 }
558 loading_isq |= topology_requires_post_quant;
559 loading_isq |= self.config.from_uqff.is_some();
560
561 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
562 anyhow::bail!(
563 "`imatrix` and `calibration_file` were both specified, this is not allowed."
564 );
565 }
566
567 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
573 loading_isq = false;
574 if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
575 Device::Cpu
576 } else {
577 device.clone()
578 }
579 } else {
580 Device::Cpu
581 };
582
583 let attention_mechanism = if paged_attn_config.is_some() {
584 AttentionImplementation::PagedAttention
585 } else {
586 AttentionImplementation::Eager
587 };
588
589 let multi_progress = Arc::new(new_multi_progress());
590
591 let mut model = if use_nccl || use_ring() {
592 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
593 dtype,
594 &device,
595 &available_devices,
596 silent,
597 &config,
598 loading_isq,
599 self.config.from_uqff.is_some(),
600 self.config.organization,
601 &*self.inner,
602 paths.as_ref(),
603 )?;
604
605 match self.kind {
607 ModelKind::Normal => multimodal_normal_model_loader_sharded!(
608 sharded_vb,
609 config,
610 self.inner,
611 mapper,
612 loading_isq,
613 device.clone(),
614 attention_mechanism,
615 multi_progress.clone(),
616 matformer_slicing_config.clone(),
617 ),
618 _ => unreachable!(),
619 }
620 } else {
621 match self.kind {
622 ModelKind::Normal => multimodal_normal_model_loader!(
623 paths,
624 Some(dtype),
625 &load_device,
626 layer_devices.clone(),
627 config,
628 self.inner,
629 silent,
630 mapper,
631 loading_isq,
632 self.config.from_uqff.is_some(),
633 device.clone(),
634 attention_mechanism,
635 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
636 multi_progress,
637 matformer_slicing_config.clone(),
638 ),
639 _ => unreachable!(),
640 }
641 };
642
643 let processor_config_json = paths
644 .get_processor_config()
645 .as_ref()
646 .map(|f| fs::read_to_string(f).unwrap());
647
648 let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
651 {
652 Some(preprocessor_config) => {
653 serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
654 }
655 None => processor_config_json.as_deref().map_or_else(
656 PreProcessorConfig::default,
657 |json| match PreProcessorConfig::from_processor_config_json(json) {
658 Ok(config) => config,
659 Err(err) => {
660 warn!(
661 "Failed to synthesize preprocessor config from processor_config.json: {err}"
662 );
663 PreProcessorConfig::default()
664 }
665 },
666 ),
667 };
668 let processor_config: Option<ProcessorConfig> = processor_config_json
669 .as_deref()
670 .map(|json| serde_json::from_str(json).unwrap());
671
672 let processor = self.inner.get_processor(
673 &config,
674 processor_config,
675 preprocessor_config.clone(),
676 self.config.max_edge,
677 ); let tokenizer = get_tokenizer(
680 paths.get_tokenizer_filename(),
681 Some(processor.get_special_tokens()),
682 )?;
683
684 let gen_conf: Option<GenerationConfig> = paths
685 .get_gen_conf_filename()
686 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
687 let chat_template_explicit = paths
688 .get_chat_template_explicit()
689 .as_ref()
690 .map(|x| x.to_string_lossy().to_string());
691 let mut chat_template = get_chat_template(
692 paths,
693 self.jinja_explicit.as_ref(),
694 chat_template_explicit.as_ref(),
695 self.chat_template.as_ref(),
696 None,
697 );
698
699 if chat_template.chat_template.is_none() {
701 if let Some(default_tmpl) = self.inner.default_chat_template(&config) {
702 info!("Using loader's built-in default chat template.");
703 chat_template.chat_template = Some(ChatTemplateValue(Either::Left(default_tmpl)));
704 }
705 }
706
707 if let Some((bos, eos)) = self.inner.default_bos_eos(&config) {
710 if chat_template.bos_token.is_none() {
711 chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos)));
712 }
713 if chat_template.eos_token.is_none() {
714 chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos)));
715 }
716 }
717
718 if let Some(calibration_file) = &self.config.calibration_file {
719 let calibration_data = std::fs::read_to_string(calibration_file)?;
720 let tokens = tokenizer
722 .encode_fast(calibration_data, false)
723 .map_err(anyhow::Error::msg)?
724 .get_ids()
725 .to_vec();
726 info!(
727 "Collecting imatrix from calibration file `{}` of {} tokens.",
728 calibration_file.display(),
729 tokens.len()
730 );
731 let bos_tok_id = chat_template
732 .bos_tok()
733 .as_deref()
734 .and_then(|tok| tokenizer.token_to_id(tok));
735
736 match self.config.organization {
739 IsqOrganization::Default => model.begin_track_stats()?,
740 IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
741 }
742
743 const CHUNK_SIZE: usize = 1024;
744 let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
745 let start = Instant::now();
746 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
747 let mut chunk = chunk.to_vec();
748 if let Some(bos_tok_id) = bos_tok_id {
749 chunk.insert(0, bos_tok_id);
750 }
751 let chunk_len = chunk.len();
752
753 let start = Instant::now();
754 let inputs = make_prompt_chunk(
755 0,
756 vec![&chunk],
757 &[0],
758 &load_device,
759 None,
760 false,
761 None,
762 None,
763 None,
764 model.config().sliding_window,
765 )?;
766 let _ = model.forward(
767 &inputs.input,
768 None, &inputs.positions,
770 inputs.context_lens,
771 inputs.position_ids,
772 model.default_model_specific_args(&inputs.input),
773 None,
774 &inputs.flash_meta,
775 )?;
776 match model.cache_mut() {
777 EitherCache::Full(full) => {
778 for layer in &mut *full.lock() {
779 *layer = None
780 }
781 }
782 EitherCache::Normal(normal) => {
783 for layer in &mut *normal.lock().unwrap().0 {
784 layer.reset();
785 }
786 }
787 EitherCache::Hybrid(hybrid) => {
788 hybrid.lock().unwrap().reset();
789 }
790 }
791 let end = Instant::now();
792 info!(
793 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
794 i + 1,
795 end.duration_since(start).as_secs_f32()
796 );
797 }
798 load_device.synchronize()?;
799 let end = Instant::now();
800 info!(
801 "Finished collecting imatrix in {:.2}s",
802 end.duration_since(start).as_secs_f32()
803 );
804 }
805
806 let should_serialize = self.config.write_uqff.is_some();
807 let should_quantize_pass = loading_isq;
808
809 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
810 let imatrix_source = if should_quantize_pass {
811 match (
812 self.config.imatrix.as_ref(),
813 self.config.calibration_file.is_some(),
814 ) {
815 (None, false) => None,
816 (Some(file), false) => Some(ImatrixDataSource::File(file)),
817 (None, true) => Some(ImatrixDataSource::Collected),
818 (Some(_), true) => unreachable!(),
819 }
820 } else {
821 None
822 };
823 if should_quantize_pass {
824 info!("Applying ISQ to all ranks.");
825 } else {
826 info!("Serializing existing ISQ tensors without additional quantization.");
827 }
828 model.quantize(
829 in_situ_quant,
830 device.clone(),
831 self.config.topology.as_ref(),
832 silent,
833 imatrix_source,
834 self.config.organization,
835 should_quantize_pass,
836 self.config.write_uqff.as_ref(),
837 UqffFullSer {
838 tokenizer: &tokenizer,
839 template_filename: paths.get_template_filename(),
840 generation_config: paths.get_gen_conf_filename(),
841 config: config.clone(),
842 processor_filename: paths.get_processor_config(),
843 preprocessor_filename: paths.get_preprocessor_config(),
844 modules: None,
845 module_paths: None,
846 },
847 Arc::new(new_multi_progress()),
848 )?;
849 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
850 model.load_from_artifacts(
851 device.clone(),
852 self.config.topology.as_ref(),
853 silent,
854 from_uqff,
855 )?;
856 }
857
858 let model_metadata = model.model_config();
859 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
860 anyhow::ensure!(
861 !matches!(self.kind, ModelKind::Adapter { .. }),
862 "PagedAttention does not support adapter models."
863 );
864 let cache_config = calculate_cache_config(
865 paged_attn_config.mem_gpu,
866 paged_attn_config.block_size,
867 dtype,
868 paged_attn_config.cache_type,
869 model_metadata.as_ref(),
870 &device,
871 &layer_devices,
872 silent,
873 None,
874 max_kv_tokens,
875 )?;
876 let cache_engine = CacheEngine::new(
877 model_metadata.as_ref(),
878 &cache_config,
879 dtype,
880 &device,
881 layer_devices,
882 )?;
883 (Some(cache_config), Some(cache_engine))
884 } else {
885 (None, None)
886 };
887
888 let max_seq_len = model.max_seq_len();
889 let llg_factory = build_llg_factory(tokenizer.clone())?;
890 let num_hidden_layers = match model.cache() {
891 EitherCache::Full(full) => full.lock().len(),
892 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
893 EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
894 };
895 let generation_defaults = gen_conf
896 .as_ref()
897 .and_then(GenerationConfig::generation_defaults);
898 let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
899 let sliding_window = model.config().sliding_window;
900 Ok(Arc::new(Mutex::new(MultimodalPipeline {
901 model,
902 tokenizer: tokenizer.into(),
903 chat_template: Arc::new(chat_template),
904 model_id: self.model_id.clone(),
905 metadata: Arc::new(GeneralMetadata {
906 max_seq_len,
907 llg_factory: Some(llg_factory),
908 is_xlora: false,
909 num_hidden_layers,
910 eos_tok: eos,
911 kind: self.kind.clone(),
912 no_kv_cache: false,
913 no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
914 activation_dtype: dtype,
915 sliding_window,
916 cache_config,
917 cache_engine,
918 model_metadata: Some(model_metadata),
919 modalities: self.inner.modalities(&config)?,
920 }),
921 processor,
922 prefixer: self.inner.prefixer(&config),
923 preprocessor_config: Arc::new(preprocessor_config),
924 topology: self.config.topology.clone(),
925 silent,
926 organization: self.config.organization,
927 template_filename: paths.get_template_filename().clone(),
928 generation_config: paths.get_gen_conf_filename().cloned(),
929 generation_defaults,
930 config,
931 processor_filename: paths.get_processor_config().clone(),
932 preprocessor_filename: paths.get_preprocessor_config().clone(),
933 mapper: pipeline_mapper,
934 imatrix: self.config.imatrix.clone(),
935 })))
936 }
937
938 fn get_id(&self) -> String {
939 self.model_id.to_string()
940 }
941
942 fn get_kind(&self) -> ModelKind {
943 self.kind.clone()
944 }
945}
946
947impl PreProcessingMixin for MultimodalPipeline {
948 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
949 Some(self.chat_template.clone())
950 }
951 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
952 Some(self.preprocessor_config.clone())
953 }
954 fn get_processor(&self) -> Arc<dyn super::Processor> {
955 self.processor.clone()
956 }
957}
958
959impl IsqPipelineMixin for MultimodalPipeline {
960 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
961 let device = self.device().clone();
962 self.model
963 .quantize(
964 Some(dtype),
965 device,
966 self.topology.as_ref(),
967 self.silent,
968 self.imatrix.as_ref().map(ImatrixDataSource::File),
969 self.organization,
970 true,
971 None,
972 UqffFullSer {
973 tokenizer: &self.tokenizer,
974 template_filename: &self.template_filename,
975 generation_config: self.generation_config.as_ref(),
976 config: self.config.clone(),
977 processor_filename: &self.processor_filename,
978 preprocessor_filename: &self.preprocessor_filename,
979 modules: None,
980 module_paths: None,
981 },
982 Arc::new(new_multi_progress()),
983 )
984 .map_err(anyhow::Error::msg)
985 }
986}
987
988impl CacheManagerMixin for MultimodalPipeline {
989 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
990 match self.model.cache() {
991 EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
992 EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
993 EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
994 }
995 }
996 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
997 match self.model.cache() {
998 EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
999 EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
1000 EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
1001 }
1002 }
1003 fn set_none_cache(
1004 &self,
1005 seqs: &mut [&mut Sequence],
1006 reset_non_granular: bool,
1007 modify_draft_cache: bool,
1008
1009 load_preallocated_cache: bool,
1010 ) {
1011 match self.model.cache() {
1012 EitherCache::Full(_) => {
1013 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
1014 }
1015 EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
1016 self,
1017 seqs,
1018 modify_draft_cache,
1019 load_preallocated_cache,
1020 ),
1021 EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
1022 self,
1023 seqs,
1024 modify_draft_cache,
1025 load_preallocated_cache,
1026 ),
1027 }
1028 self.model.reset_model_specific_state();
1032
1033 if reset_non_granular {
1034 self.reset_non_granular_state()
1035 }
1036 }
1037 fn cache(&self) -> &EitherCache {
1038 self.model.cache()
1039 }
1040}
1041
1042impl MetadataMixin for MultimodalPipeline {
1043 fn device(&self) -> Device {
1044 self.model.device().clone()
1045 }
1046 fn get_metadata(&self) -> Arc<GeneralMetadata> {
1047 self.metadata.clone()
1048 }
1049 fn name(&self) -> String {
1050 self.model_id.clone()
1051 }
1052 fn reset_non_granular_state(&self) {
1053 self.model.reset_model_specific_state();
1054 }
1055 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1056 Some(self.tokenizer.clone())
1057 }
1058 fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
1059 self.generation_defaults.clone()
1060 }
1061 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1062 Some(&*self.mapper)
1063 }
1064}
1065
1066#[async_trait::async_trait]
1067impl Pipeline for MultimodalPipeline {
1068 fn forward_inputs(
1069 &mut self,
1070 inputs: Box<dyn Any>,
1071 return_raw_logits: bool,
1072 ) -> candle_core::Result<ForwardInputsResult> {
1073 let ModelInputs {
1074 input_ids,
1075 seqlen_offsets,
1076 context_lens,
1077 position_ids,
1078 pixel_values,
1079 model_specific_args,
1080 paged_attn_meta,
1081 flash_meta,
1082 } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
1083 let metadata = self.get_metadata();
1084 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1085 (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
1086 (Some(_), None) => {
1087 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1089 }
1090 (None, Some(_)) => {
1091 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1093 }
1094 (None, None) => None,
1095 };
1096 let logits = self.model.forward(
1097 &input_ids,
1098 pixel_values,
1099 &seqlen_offsets,
1100 context_lens,
1101 position_ids,
1102 model_specific_args,
1103 paged_attn_meta,
1104 &flash_meta,
1105 )?;
1106 if return_raw_logits {
1107 Ok(ForwardInputsResult::RawLogits { logits })
1108 } else {
1109 Ok(ForwardInputsResult::CausalGeneration { logits })
1110 }
1111 }
1112 async fn sample_causal_gen(
1113 &self,
1114 seqs: &mut [&mut Sequence],
1115 logits: Vec<Tensor>,
1116 prefix_cacher: &mut PrefixCacheManagerV2,
1117 disable_eos_stop: bool,
1118 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1119 ) -> Result<(), candle_core::Error> {
1120 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1121 }
1122 fn category(&self) -> ModelCategory {
1123 ModelCategory::Multimodal {
1124 prefixer: self.prefixer.clone(),
1125 }
1126 }
1127
1128 fn encoder_cache_counters(
1129 &self,
1130 ) -> Option<(
1131 std::sync::Arc<std::sync::atomic::AtomicUsize>,
1132 std::sync::Arc<std::sync::atomic::AtomicUsize>,
1133 )> {
1134 self.model.encoder_cache_counters()
1135 }
1136}
1137
1138impl AnyMoePipelineMixin for MultimodalPipeline {
1139 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1140 self.model.finish_training(gate_model_id)
1141 }
1142 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1143 self.model.get_vars()
1144 }
1145 fn amoe_base_model_trainable_params(&self) -> usize {
1146 self.model.trainable_params()
1147 }
1148 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1149 self.model.take_cached_gating_outputs()
1150 }
1151 fn amoe_create_layers(
1152 &mut self,
1153 model_ids: Vec<String>,
1154 token: &TokenSource,
1155 revision: Option<String>,
1156 match_regex: &str,
1157 config: crate::amoe::AnyMoeConfig,
1158 dtype: candle_core::DType,
1159 dev: &Device,
1160 (prefix, mlp): (String, String),
1161 layers: Vec<usize>,
1162 expert_type: AnyMoeExpertType,
1163 silent: bool,
1164 gate_model_id: Option<String>,
1165 ) -> candle_core::Result<()> {
1166 let mut vbs = Vec::new();
1167 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1169 for model_id in model_ids {
1170 let model_id_str = &model_id;
1171 let model_id = Path::new(&model_id);
1172
1173 let api = {
1174 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1175 let mut api = ApiBuilder::from_cache(cache)
1176 .with_progress(!silent)
1177 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1178 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1179 api = api.with_cache_dir(cache_dir);
1180 }
1181 api.build().map_err(candle_core::Error::msg)?
1182 };
1183 let revision = revision.clone().unwrap_or("main".to_string());
1184 let api = api.repo(Repo::with_revision(
1185 model_id_str.clone(),
1186 RepoType::Model,
1187 revision.clone(),
1188 ));
1189
1190 let mut filenames = vec![];
1191 for rfilename in
1192 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1193 {
1194 filenames.push(api_get_file!(api, &rfilename, model_id));
1195 }
1196
1197 let regex = regex.clone();
1198 let match_regex_clone = match_regex.to_string();
1199 let layers_clone = layers.clone();
1200 let vb = from_mmaped_safetensors(
1201 filenames,
1202 vec![],
1203 Some(dtype),
1204 dev,
1205 vec![None],
1206 silent,
1207 None,
1208 move |key| {
1209 if regex.is_match(&key) {
1210 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1213 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1214 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1215 .parse::<usize>()
1216 .unwrap();
1217 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1218 } else {
1219 false
1220 }
1221 },
1222 Arc::new(|_| DeviceForLoadTensor::Base),
1223 )?;
1224 vbs.push(vb);
1225 }
1226
1227 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1228 let model_id_str = &gate_model_id;
1229 let model_id = Path::new(&gate_model_id);
1230
1231 let api = {
1232 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1233 let mut api = ApiBuilder::from_cache(cache)
1234 .with_progress(!silent)
1235 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1236 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1237 api = api.with_cache_dir(cache_dir);
1238 }
1239 api.build().map_err(candle_core::Error::msg)?
1240 };
1241 let revision = revision.clone().unwrap_or("main".to_string());
1242 let api = api.repo(Repo::with_revision(
1243 model_id_str.clone(),
1244 RepoType::Model,
1245 revision.clone(),
1246 ));
1247
1248 let mut gate_filenames = vec![];
1249 for rfilename in
1250 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1251 {
1252 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1253 }
1254 assert_eq!(
1255 gate_filenames.len(),
1256 1,
1257 "Gate model ID must contain only one .safetensors file"
1258 );
1259
1260 let vb = from_mmaped_safetensors(
1261 gate_filenames.clone(),
1262 vec![],
1263 Some(dtype),
1264 dev,
1265 vec![None],
1266 silent,
1267 None,
1268 |_| true,
1269 Arc::new(|_| DeviceForLoadTensor::Base),
1270 )?;
1271 info!(
1272 "Loaded gating layers from `{}`",
1273 gate_filenames[0].display()
1274 );
1275 Some(vb)
1276 } else {
1277 None
1278 };
1279
1280 self.model
1281 .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
1282 }
1283 fn amoe_supported(&self) -> bool {
1284 self.model.amoe_supported()
1285 }
1286}