1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6 TokenSource,
7};
8use super::{
9 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
14 GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
15 MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
16 Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::attention::ATTENTION_CHUNK_SIZE;
20use crate::device_map::{self, DeviceMapper};
21use crate::distributed::{self, WorkerTransferData};
22use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
23use crate::lora::Ordering;
24use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
25use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
26use crate::pipeline::isq::UqffFullSer;
27use crate::pipeline::loaders::auto_device_map;
28use crate::pipeline::loaders::QuantizationConfigShim;
29use crate::pipeline::sampling::sample_and_add_toks;
30use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
31use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
32use crate::pipeline::{ChatTemplate, LocalModelPaths};
33use crate::prefix_cacher::PrefixCacheManagerV2;
34use crate::sequence::Sequence;
35use crate::utils::tokenizer::get_tokenizer;
36use crate::utils::varbuilder_utils::DeviceForLoadTensor;
37use crate::utils::{
38 progress::{new_multi_progress, ProgressScopeGuard},
39 tokens::get_token,
40 varbuilder_utils::from_mmaped_safetensors,
41};
42use crate::xlora_models::NonGranularState;
43use crate::{
44 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
45 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
46 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
47};
48use anyhow::Result;
49use candle_core::{Device, Tensor, Var};
50use hf_hub::Cache;
51use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
52use mistralrs_quant::log::once_log_info;
53use mistralrs_quant::{
54 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
55};
56use rand_isaac::Isaac64Rng;
57use regex_automata::meta::Regex;
58use std::any::Any;
59use std::borrow::Cow;
60use std::path::{Path, PathBuf};
61use std::str::FromStr;
62use std::sync::{Arc, RwLock};
63use std::time::Instant;
64use std::{env, fs};
65use tokenizers::Tokenizer;
66use tokio::sync::Mutex;
67use tracing::{info, warn};
68
/// Text-in/text-out pipeline produced by [`NormalLoader`].
///
/// Bundles the loaded model with the tokenizer, chat template and the settings that
/// are needed again later (for example when the model is re-quantized through
/// `re_isq_model`).
pub struct NormalPipeline {
    /// The loaded model; also the source of the pipeline's device and KV cache.
    model: Box<dyn NormalModel + Send + Sync>,
    /// Tokenizer shared with callers through `MetadataMixin::tokenizer`.
    tokenizer: Arc<Tokenizer>,
    /// Copied from the loader's `no_kv_cache` setting.
    no_kv_cache: bool,
    /// Chat template resolved at load time.
    chat_template: Arc<ChatTemplate>,
    /// Present only when the loader was given a `tgt_non_granular_index`;
    /// cleared by `reset_non_granular_state`.
    non_granular_state: Option<NonGranularState>,
    /// Model identifier, returned by `MetadataMixin::name`.
    model_id: String,
    /// Shared metadata assembled at load time (EOS tokens, cache config, modalities, ...).
    metadata: Arc<GeneralMetadata>,
    /// Topology forwarded to `quantize` when re-applying ISQ.
    topology: Option<Topology>,
    /// Forwarded to `quantize` when re-applying ISQ.
    silent: bool,
    /// ISQ organization forwarded to `quantize` when re-applying ISQ.
    organization: IsqOrganization,
    /// Template file path handed to `UqffFullSer` when re-applying ISQ.
    template_filename: Option<PathBuf>,
    /// Generation config path handed to `UqffFullSer` when re-applying ISQ.
    generation_config: Option<PathBuf>,
    /// Defaults derived from the parsed generation config, if one was loaded.
    generation_defaults: Option<crate::ModelGenerationDefaults>,
    /// Raw contents of the model's config file as read at load time.
    config: String,
    /// Optional imatrix file used as `ImatrixDataSource::File` when re-applying ISQ.
    imatrix: Option<PathBuf>,
    /// Device mapper built from the resolved `DeviceMapSetting`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
88
/// A [`Loader`] for text models, created through [`NormalLoaderBuilder`].
pub struct NormalLoader {
    /// Architecture-specific loader selected in `NormalLoaderBuilder::build`.
    inner: Box<dyn NormalModelLoader>,
    /// Model identifier, returned by `Loader::get_id`.
    model_id: String,
    /// Loader-specific settings (topology, ISQ organization, UQFF, imatrix, ...).
    config: NormalSpecificConfig,
    /// X-LoRA adapter model ID, set by `NormalLoaderBuilder::with_xlora`.
    xlora_model_id: Option<String>,
    /// LoRA adapter IDs, set by `NormalLoaderBuilder::with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
    /// Whether this is a plain model or an adapter (LoRA / X-LoRA) model.
    kind: ModelKind,
    /// Adapter ordering supplied together with the X-LoRA model ID.
    xlora_order: Option<Ordering>,
    /// Copied onto the resulting pipeline and its metadata.
    no_kv_cache: bool,
    /// Chat template passed to `get_chat_template` when the pipeline is built.
    chat_template: Option<String>,
    // NOTE(review): not read directly in the visible code; presumably consumed by the
    // path-resolution macros — confirm.
    tokenizer_json: Option<String>,
    /// When set, the pipeline is created with a `NonGranularState` targeting this index.
    tgt_non_granular_index: Option<usize>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// Resolved UQFF file paths, filled in by `load_model_from_hf` when
    /// `config.from_uqff` is set.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Explicit JINJA template passed to `get_chat_template`.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory used to build the `hf_hub::Cache`.
    hf_cache_path: Option<PathBuf>,
}
108
/// Builder for [`NormalLoader`].
///
/// Start with [`NormalLoaderBuilder::new`], optionally attach adapters with
/// `with_xlora` / `with_lora`, then call `build`.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    /// Model identifier; may be filled in from the X-LoRA ordering's base model ID.
    model_id: Option<String>,
    /// Loader-specific settings handed to the resulting loader unchanged.
    config: NormalSpecificConfig,
    /// X-LoRA adapter model ID, set by `with_xlora`.
    xlora_model_id: Option<String>,
    /// LoRA adapter IDs, set by `with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
    /// Plain model or adapter kind; `new` sets `ModelKind::Normal`.
    kind: ModelKind,
    /// Adapter ordering, set by `with_xlora`.
    xlora_order: Option<Ordering>,
    /// Set in `new` and overwritten by `with_xlora`.
    no_kv_cache: bool,
    /// Chat template handed to the resulting loader.
    chat_template: Option<String>,
    /// Tokenizer JSON setting handed to the resulting loader.
    tokenizer_json: Option<String>,
    /// Non-granular target index, set by `with_xlora`.
    tgt_non_granular_index: Option<usize>,
    /// Explicit JINJA template handed to the resulting loader.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory, set by `hf_cache_path`.
    hf_cache_path: Option<PathBuf>,
}
125
/// Settings specific to loading a text model with [`NormalLoader`].
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Optional topology; used when building the device mapper, for per-layer ISQ
    /// overrides, and when quantizing or loading UQFF artifacts.
    pub topology: Option<Topology>,
    /// Selects between the default ISQ predicates and the MoE-experts-only ones.
    pub organization: IsqOrganization,
    /// When set, the model is serialized to this path during the `quantize` call.
    pub write_uqff: Option<PathBuf>,
    /// UQFF files to load from; resolved to concrete paths in `load_model_from_hf`.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Imatrix file used as the ISQ data source. Must not be combined with
    /// `calibration_file`; loading fails if both are set.
    pub imatrix: Option<PathBuf>,
    /// Text file that is tokenized and run through the model to collect an imatrix.
    pub calibration_file: Option<PathBuf>,
    // NOTE(review): not read in the visible code (the loader uses its own
    // `hf_cache_path` field) — confirm where this is consumed.
    pub hf_cache_path: Option<PathBuf>,
    /// Path of a Matformer config file to load.
    pub matformer_config_path: Option<PathBuf>,
    /// Name of the Matformer slice to use; without it the loaded Matformer config is
    /// ignored and a warning is logged.
    pub matformer_slice_name: Option<String>,
}
139
140impl NormalLoaderBuilder {
141 pub fn new(
142 config: NormalSpecificConfig,
143 chat_template: Option<String>,
144 tokenizer_json: Option<String>,
145 model_id: Option<String>,
146 no_kv_cache: bool,
147 jinja_explicit: Option<String>,
148 ) -> Self {
149 Self {
150 config,
151 chat_template,
152 tokenizer_json,
153 model_id,
154 kind: ModelKind::Normal,
155 jinja_explicit,
156 no_kv_cache,
157 ..Default::default()
158 }
159 }
160
161 fn with_adapter(
162 mut self,
163 xlora_model_id: String,
164 xlora_order: Ordering,
165 no_kv_cache: bool,
166 tgt_non_granular_index: Option<usize>,
167 ) -> Self {
168 self.xlora_model_id = Some(xlora_model_id);
169 self.xlora_order = Some(xlora_order);
170 self.no_kv_cache = no_kv_cache;
171 self.tgt_non_granular_index = tgt_non_granular_index;
172 self.model_id = if let Some(id) = self.model_id {
173 Some(id)
174 } else {
175 info!(
176 "Using adapter base model ID: `{}`",
177 self.xlora_order.as_ref().unwrap().base_model_id
178 );
179 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
180 };
181 self
182 }
183
184 pub fn with_xlora(
185 mut self,
186 xlora_model_id: String,
187 xlora_order: Ordering,
188 no_kv_cache: bool,
189 tgt_non_granular_index: Option<usize>,
190 ) -> Self {
191 self.kind = ModelKind::Adapter {
192 adapter: AdapterKind::XLora,
193 };
194 self.with_adapter(
195 xlora_model_id,
196 xlora_order,
197 no_kv_cache,
198 tgt_non_granular_index,
199 )
200 }
201
202 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
203 self.kind = ModelKind::Adapter {
204 adapter: AdapterKind::Lora,
205 };
206 self.lora_adapter_ids = Some(lora_adapter_ids);
207 self
208 }
209
210 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
211 self.hf_cache_path = Some(hf_cache_path);
212 self
213 }
214
215 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
218 let loader: Box<dyn NormalModelLoader> = match loader_tp {
219 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
220 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
221 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
222 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
223 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
224 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
225 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
226 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
227 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
228 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
229 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
230 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
231 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
232 Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
233 Some(NormalLoaderType::GLM4MoeLite) => Box::new(GLM4MoeLiteLoader),
234 Some(NormalLoaderType::GLM4Moe) => Box::new(GLM4MoeLoader),
235 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
236 Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
237 Some(NormalLoaderType::GraniteMoeHybrid) => Box::new(GraniteMoeHybridLoader),
238 Some(NormalLoaderType::GptOss) => Box::new(GptOssLoader),
239 Some(NormalLoaderType::Qwen3Next) => Box::new(Qwen3NextLoader),
240 None => Box::new(AutoNormalLoader),
241 };
242 Ok(Box::new(NormalLoader {
243 inner: loader,
244 model_id: self.model_id.unwrap(),
245 config: self.config,
246 xlora_model_id: self.xlora_model_id,
247 lora_adapter_ids: self.lora_adapter_ids,
248 kind: self.kind,
249 xlora_order: self.xlora_order,
250 no_kv_cache: self.no_kv_cache,
251 chat_template: self.chat_template,
252 tokenizer_json: self.tokenizer_json,
253 tgt_non_granular_index: self.tgt_non_granular_index,
254 jinja_explicit: self.jinja_explicit,
255 token_source: RwLock::new(None),
256 revision: RwLock::new(None),
257 from_uqff: RwLock::new(None),
258 hf_cache_path: self.hf_cache_path,
259 }))
260 }
261}
262
263impl Loader for NormalLoader {
264 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
265 fn load_model_from_hf(
266 &self,
267 revision: Option<String>,
268 token_source: TokenSource,
269 dtype: &dyn TryIntoDType,
270 device: &Device,
271 silent: bool,
272 mapper: DeviceMapSetting,
273 in_situ_quant: Option<IsqType>,
274 paged_attn_config: Option<PagedAttentionConfig>,
275 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
276 let _progress_guard = ProgressScopeGuard::new(silent);
277 let cache = self
278 .hf_cache_path
279 .clone()
280 .map(Cache::new)
281 .unwrap_or_default();
282 GLOBAL_HF_CACHE.get_or_init(|| cache);
283
284 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
285 LocalModelPaths,
286 &token_source,
287 revision.clone(),
288 self,
289 None,
290 None,
291 silent,
292 self.config.from_uqff.is_some()
293 );
294 *self
295 .token_source
296 .write()
297 .expect("Failed to write to token source") = Some(token_source);
298 *self.revision.write().expect("Failed to write to revision") = revision.clone();
299 if let Some(from_uqff) = self.config.from_uqff.clone() {
300 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
301 }
302 self.load_model_from_path(
303 &paths?,
304 dtype,
305 device,
306 silent,
307 mapper,
308 in_situ_quant,
309 paged_attn_config,
310 )
311 }
312
313 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
314 fn load_model_from_path(
315 &self,
316 paths: &Box<dyn ModelPaths>,
317 dtype: &dyn TryIntoDType,
318 device: &Device,
319 silent: bool,
320 mut mapper: DeviceMapSetting,
321 in_situ_quant: Option<IsqType>,
322 mut paged_attn_config: Option<PagedAttentionConfig>,
323 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
324 let _progress_guard = ProgressScopeGuard::new(silent);
325 let config = std::fs::read_to_string(paths.get_config_filename())?;
326
327 if !self.inner.supports_paged_attention(&config)? {
328 paged_attn_config = None;
329 }
330
331 info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
332
333 let use_nccl = mistralrs_quant::distributed::use_nccl();
334
335 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
336 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
337 let WorkerTransferData::Init { id: _, worker_rank } = payload;
338 vec![candle_core::Device::new_cuda(worker_rank + 1)?]
339 } else if use_nccl {
340 vec![candle_core::Device::new_cuda(0)?]
341 } else {
342 device_map::get_all_similar_devices(device)?
343 };
344 #[cfg(feature = "cuda")]
345 for device in &available_devices {
346 if let Device::Cuda(dev) = device {
347 unsafe { dev.disable_event_tracking() };
348 }
349 }
350 let device = if use_nccl || cfg!(feature = "ring") {
351 available_devices[0].clone()
352 } else {
353 device.clone()
354 };
355
356 let mut max_kv_tokens: Option<usize> = None;
358 if use_nccl || cfg!(feature = "ring") {
359 mapper = DeviceMapSetting::DummyNccl {
360 nm_device: available_devices[0].clone(),
361 };
362 } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
363 max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
364 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
366
367 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
370 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
371 let weight_pack_factor = {
372 let ser_artifacts = unsafe {
373 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
374 };
375 let mut total_pack_factors = 0;
376 let total_tensors = ser_artifacts.tensors().len();
377 for (_, artifact) in ser_artifacts.tensors() {
378 let artifact = artifact.data();
379 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
381 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
382 {
383 QuantizedSerdeType::Hqq => {
384 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
385 .pack_factor(dtype)
386 }
387 QuantizedSerdeType::Gguf => {
388 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
389 .pack_factor(dtype)
390 }
391 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
392 QuantizedSerdeType::Unquant => 1,
393 QuantizedSerdeType::Afq => {
394 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
395 .pack_factor(dtype)
396 }
397 QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
398 QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
399 };
400 total_pack_factors += pack_factor;
401 }
402
403 total_pack_factors / total_tensors
404 };
405
406 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
407 &config,
408 dtype,
409 weight_pack_factor,
410 None,
411 )?;
412 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
413 &config,
414 dtype,
415 weight_pack_factor,
416 None,
417 )?;
418 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
419 (
420 layer_sizes_in_bytes,
421 non_mapped_size_in_bytes,
422 layer_sizes_sum + non_mapped_size_in_bytes,
423 )
424 } else if let Some(isq) = in_situ_quant {
425 let weight_pack_factor = isq.pack_factor(dtype);
426 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
427 &config,
428 dtype,
429 weight_pack_factor,
430 None,
431 )?;
432 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
433 &config,
434 dtype,
435 weight_pack_factor,
436 None,
437 )?;
438 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
439 (
440 layer_sizes_in_bytes,
441 non_mapped_size_in_bytes,
442 layer_sizes_sum + non_mapped_size_in_bytes,
443 )
444 } else {
445 let weight_pack_factor =
447 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
448 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
449 &config,
450 dtype,
451 weight_pack_factor,
452 None,
453 )?;
454 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
455 &config,
456 dtype,
457 weight_pack_factor,
458 None,
459 )?;
460 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
461 (
462 layer_sizes_in_bytes,
463 non_mapped_size_in_bytes,
464 layer_sizes_sum + non_mapped_size_in_bytes,
465 )
466 };
467
468 let new = auto_device_map::get_device_layers(
469 &*self.inner,
470 &config,
471 self.inner.num_layers(&config)?,
472 layer_sizes_in_bytes,
473 non_mapped_size_in_bytes,
474 total_model_size_in_bytes,
475 &available_devices,
476 dtype,
477 ¶ms,
478 paged_attn_config.as_ref(),
479 )?;
480 mapper = DeviceMapSetting::Map(new);
481 }
482
483 let pipeline_mapper = mapper.into_mapper(
484 self.inner.num_layers(&config)?,
485 &device,
486 self.config.topology.as_ref(),
487 &available_devices,
488 )?;
489 let mapper = mapper.into_mapper(
490 self.inner.num_layers(&config)?,
491 &device,
492 self.config.topology.as_ref(),
493 &available_devices,
494 )?;
495 let mut layer_devices = Vec::new();
496 for layer in 0..self.inner.num_layers(&config)? {
497 let device = mapper.device_for(layer, false).cloned();
498 layer_devices.push(device);
499 }
500 let dtype = mapper.get_min_dtype(dtype)?;
501
502 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
505 if mapping_uses_cpu && paged_attn_config.is_some() {
506 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
507 paged_attn_config = None;
508 }
509
510 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
511 if crate::using_flash_attn() {
512 once_log_info("FlashAttention is enabled.");
513 }
514
515 let topology_overrides = self
516 .config
517 .topology
518 .as_ref()
519 .map(|topology| {
520 topology
521 .pattern_overrides()
522 .into_iter()
523 .map(|(regex, layer)| ImmediateIsqOverride {
524 predicate: regex,
525 ty: layer.isq,
526 device: layer.device.clone(),
527 })
528 .collect::<Vec<_>>()
529 })
530 .unwrap_or_default();
531 let has_override_isq = topology_overrides
532 .iter()
533 .any(|override_entry| override_entry.ty.is_some());
534 let topology_requires_post_quant = self
535 .config
536 .topology
537 .as_ref()
538 .is_some_and(|topology| topology.requires_post_quantization());
539
540 let allow_immediate_cli = self.config.imatrix.is_none()
541 && self.config.calibration_file.is_none()
542 && in_situ_quant.is_some();
543
544 let mut immediate_ty = None;
545 let mut immediate_predicates = Vec::new();
546 if allow_immediate_cli {
547 immediate_ty = in_situ_quant;
548 immediate_predicates =
549 if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
550 self.inner.immediate_isq_predicates_moqe(&config)?
551 } else {
552 self.inner.immediate_isq_predicates(&config)?
553 };
554 info!("Applying ISQ to {in_situ_quant:?}");
555 if immediate_predicates.is_empty() {
556 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
557 }
558 }
559
560 let use_immediate = allow_immediate_cli || has_override_isq;
561 if use_immediate {
562 let (pool, num_threads) = mistralrs_quant::create_isq_thread_pool(immediate_ty);
563 info!("Applying immediate ISQ in parallel on {num_threads} threads.");
564 mistralrs_quant::set_immediate_isq_with_pool(
565 immediate_ty,
566 immediate_predicates.clone(),
567 topology_overrides.clone(),
568 pool,
569 );
570 }
571
572 let mut loading_isq = if use_immediate {
574 false
575 } else {
576 in_situ_quant.is_some()
577 };
578 if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
579 loading_isq = true;
580 }
581 loading_isq |= topology_requires_post_quant;
582 loading_isq |= self.config.from_uqff.is_some();
583
584 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
585 anyhow::bail!(
586 "`imatrix` and `calibration_file` were both specified, this is not allowed."
587 );
588 }
589
590 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
596 loading_isq = false;
597 if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
598 Device::Cpu
599 } else {
600 device.clone()
601 }
602 } else {
603 Device::Cpu
604 };
605
606 let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
607
608 let attention_mechanism = if paged_attn_config.is_some() {
609 AttentionImplementation::PagedAttention
610 } else {
611 AttentionImplementation::Eager
612 };
613
614 let multi_progress = Arc::new(new_multi_progress());
615
616 let matformer_slicing_config = if let Some(matformer_path) =
618 &self.config.matformer_config_path
619 {
620 use crate::matformer::{MatformerConfig, MatformerSliceConfig};
621 info!("Loading Matformer config from {:?}", matformer_path);
622 let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
623
624 if let Some(slice_name) = &self.config.matformer_slice_name {
625 info!("Using Matformer slice: {}", slice_name);
626 Some(MatformerSliceConfig::new(slice_name.clone(), config))
627 } else {
628 warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
631 None
632 }
633 } else {
634 None
635 };
636
637 let mut model = if use_nccl || cfg!(feature = "ring") {
638 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
639 dtype,
640 &device,
641 &available_devices,
642 silent,
643 &config,
644 loading_isq,
645 self.config.from_uqff.is_some(),
646 self.config.organization,
647 &*self.inner,
648 paths.as_ref(),
649 )?;
650
651 match self.kind {
653 ModelKind::Normal => normal_model_loader_sharded!(
654 sharded_vb,
655 config,
656 self.inner,
657 mapper,
658 loading_isq,
659 device.clone(),
660 attention_mechanism,
661 multi_progress.clone(),
662 matformer_slicing_config.clone(),
663 ),
664 ModelKind::Adapter {
665 adapter: AdapterKind::XLora,
666 } => xlora_model_loader!(
667 paths,
668 Some(dtype),
669 &load_device,
670 layer_devices.clone(),
671 config,
672 self.inner,
673 silent,
674 mapper,
675 loading_isq,
676 device.clone(),
677 multi_progress.clone(),
678 matformer_slicing_config.clone(),
679 ),
680 ModelKind::Adapter {
681 adapter: AdapterKind::Lora,
682 } => lora_model_loader!(
683 paths,
684 Some(dtype),
685 &load_device,
686 layer_devices.clone(),
687 config,
688 self.inner,
689 silent,
690 mapper,
691 loading_isq,
692 self.config.from_uqff.is_some(),
693 device.clone(),
694 attention_mechanism,
695 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
696 multi_progress.clone(),
697 matformer_slicing_config.clone(),
698 ),
699 _ => unreachable!(),
700 }
701 } else {
702 match self.kind {
703 ModelKind::Normal => normal_model_loader!(
704 paths,
705 Some(dtype),
706 &load_device,
707 layer_devices.clone(),
708 config,
709 self.inner,
710 silent,
711 mapper,
712 loading_isq,
713 self.config.from_uqff.is_some(),
714 device.clone(),
715 attention_mechanism,
716 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
717 multi_progress.clone(),
718 matformer_slicing_config.clone(),
719 ),
720 ModelKind::Adapter {
721 adapter: AdapterKind::XLora,
722 } => xlora_model_loader!(
723 paths,
724 Some(dtype),
725 &load_device,
726 layer_devices.clone(),
727 config,
728 self.inner,
729 silent,
730 mapper,
731 loading_isq,
732 device.clone(),
733 multi_progress.clone(),
734 matformer_slicing_config.clone(),
735 ),
736 ModelKind::Adapter {
737 adapter: AdapterKind::Lora,
738 } => lora_model_loader!(
739 paths,
740 Some(dtype),
741 &load_device,
742 layer_devices.clone(),
743 config,
744 self.inner,
745 silent,
746 mapper,
747 loading_isq,
748 self.config.from_uqff.is_some(),
749 device.clone(),
750 attention_mechanism,
751 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
752 multi_progress.clone(),
753 matformer_slicing_config.clone(),
754 ),
755 _ => unreachable!(),
756 }
757 };
758
759 let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
760 let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
761 match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
762 Ok(conf) => Some(conf),
763 Err(e) => {
764 warn!("Failed to parse generation_config.json: {}", e);
765 None
766 }
767 }
768 });
769
770 let chat_template_explicit = paths
771 .get_chat_template_explicit()
772 .as_ref()
773 .map(|x| x.to_string_lossy().to_string());
774 let chat_template = get_chat_template(
775 paths,
776 self.jinja_explicit.as_ref(),
777 chat_template_explicit.as_ref(),
778 self.chat_template.as_ref(),
779 None,
780 );
781
782 if let Some(calibration_file) = &self.config.calibration_file {
783 let calibration_data = std::fs::read_to_string(calibration_file)?;
784 let tokens = tokenizer
786 .encode_fast(calibration_data, false)
787 .map_err(anyhow::Error::msg)?
788 .get_ids()
789 .to_vec();
790 info!(
791 "Collecting imatrix from calibration file `{}` of {} tokens.",
792 calibration_file.display(),
793 tokens.len()
794 );
795 let bos_tok_id = chat_template
796 .bos_tok()
797 .as_deref()
798 .and_then(|tok| tokenizer.token_to_id(tok));
799
800 match self.config.organization {
801 IsqOrganization::Default => model.begin_track_stats()?,
802 IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
803 }
804
805 const CHUNK_SIZE: usize = 1024;
806 let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
807 let start = Instant::now();
808 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
809 let mut chunk = chunk.to_vec();
810 if let Some(bos_tok_id) = bos_tok_id {
811 chunk.insert(0, bos_tok_id);
812 }
813 let chunk_len = chunk.len();
814
815 let start = Instant::now();
816 let inputs = make_prompt_chunk(
817 0,
818 vec![&chunk],
819 &[0],
820 &load_device,
821 None,
822 false,
823 None,
824 Some(pipeline_mapper.as_ref()),
825 None,
826 model.config().sliding_window,
827 )?;
828
829 model.forward(
830 &inputs.input.to_device(model.device())?,
831 &inputs.positions,
832 inputs.context_lens.clone(),
833 inputs.position_ids.clone(),
834 None,
835 &inputs.flash_meta.clone(),
836 )?;
837
838 match model.cache_mut() {
839 EitherCache::Full(full) => {
840 for layer in &mut *full.lock() {
841 *layer = None
842 }
843 }
844 EitherCache::Normal(normal) => {
845 for layer in &mut *normal.lock().unwrap().0 {
846 layer.reset();
847 }
848 }
849 EitherCache::Hybrid(hybrid) => {
850 hybrid.lock().unwrap().reset();
851 }
852 }
853
854 let end = Instant::now();
855 info!(
856 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
857 i + 1,
858 end.duration_since(start).as_secs_f32()
859 );
860 }
861 load_device.synchronize()?;
862 let end = Instant::now();
863 info!(
864 "Finished collecting imatrix in {:.2}s",
865 end.duration_since(start).as_secs_f32()
866 );
867 }
868
869 let should_serialize = self.config.write_uqff.is_some();
871 let should_quantize_pass = loading_isq;
872
873 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
874 let imatrix_source = if should_quantize_pass {
875 match (
876 self.config.imatrix.as_ref(),
877 self.config.calibration_file.is_some(),
878 ) {
879 (None, false) => None,
880 (Some(file), false) => Some(ImatrixDataSource::File(file)),
881 (None, true) => Some(ImatrixDataSource::Collected),
882 (Some(_), true) => unreachable!(),
883 }
884 } else {
885 None
886 };
887
888 if should_quantize_pass {
889 info!("Applying ISQ to all ranks.");
890 } else {
891 info!("Serializing existing ISQ tensors without additional quantization.");
892 }
893
894 let multi_progress = Arc::new(new_multi_progress());
895
896 model.quantize(
897 in_situ_quant,
898 model.device().clone(),
899 self.config.topology.as_ref(),
900 silent,
901 imatrix_source,
902 self.config.organization,
903 should_quantize_pass,
904 self.config.write_uqff.as_ref(),
905 UqffFullSer {
906 tokenizer: &tokenizer,
907 template_filename: paths.get_template_filename(),
908 generation_config: paths.get_gen_conf_filename(),
909 config: config.clone(),
910 processor_filename: &None,
911 preprocessor_filename: &None,
912 modules: None,
913 module_paths: None,
914 },
915 multi_progress.clone(),
916 )?;
917 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
918 model.load_from_artifacts(
919 device.clone(),
920 self.config.topology.as_ref(),
921 silent,
922 from_uqff,
923 )?;
924 }
925
926 let paged_attn_config = if matches!(
927 self.kind,
928 ModelKind::Adapter {
929 adapter: AdapterKind::XLora
930 }
931 ) {
932 warn!(
933 "Adapter parallel_models do not currently support PagedAttention, running without"
934 );
935 None
936 } else {
937 paged_attn_config
938 };
939
940 let model_metadata = model.model_config();
941 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
942 let cache_config = calculate_cache_config(
943 paged_attn_config.mem_gpu,
944 paged_attn_config.block_size,
945 dtype,
946 paged_attn_config.cache_type,
947 model_metadata.as_ref(),
948 &device,
949 &pipeline_mapper
950 .get_unique_devices()
951 .into_iter()
952 .map(Some)
953 .collect::<Vec<_>>(),
954 silent,
955 None,
956 max_kv_tokens,
957 )?;
958
959 let mut layer_devices = Vec::new();
960 for layer in 0..self.inner.num_layers(&config)? {
961 let device = model.get_layers().1.device_for(layer, false).cloned();
962 layer_devices.push(device);
963 }
964 let cache_engine = CacheEngine::new(
965 model_metadata.as_ref(),
966 &cache_config,
967 dtype,
968 model.device(),
969 layer_devices.clone(),
970 )?;
971
972 (Some(cache_config), Some(cache_engine))
973 } else {
974 (None, None)
975 };
976
977 let max_seq_len = model.max_seq_len();
978 let llg_factory = build_llg_factory(tokenizer.clone())?;
979 let num_hidden_layers = match model.cache() {
980 EitherCache::Full(full) => full.lock().len(),
981 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
982 EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
983 };
984 let generation_defaults = gen_conf
985 .as_ref()
986 .and_then(GenerationConfig::generation_defaults);
987 let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
988 let sliding_window = model.config().sliding_window;
989 Ok(Arc::new(Mutex::new(NormalPipeline {
990 model,
991 tokenizer: tokenizer.into(),
992 no_kv_cache: self.no_kv_cache,
993 chat_template: Arc::new(chat_template),
994 non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
995 NonGranularState {
996 non_granular_index: Arc::new(Mutex::new(0)),
997 tgt_non_granular_index,
998 }
999 }),
1000 model_id: self.model_id.clone(),
1001 metadata: Arc::new(GeneralMetadata {
1002 max_seq_len,
1003 llg_factory: Some(llg_factory),
1004 no_kv_cache: self.no_kv_cache,
1005 no_prefix_cache: is_xlora,
1006 num_hidden_layers,
1007 eos_tok: eos,
1008 kind: self.kind.clone(),
1009 is_xlora,
1010 activation_dtype: dtype,
1011 sliding_window,
1012 cache_config,
1013 cache_engine,
1014 model_metadata: Some(model_metadata),
1015 modalities: Modalities {
1016 input: vec![SupportedModality::Text],
1017 output: vec![SupportedModality::Text],
1018 },
1019 }),
1020 topology: self.config.topology.clone(),
1021 silent,
1022 organization: self.config.organization,
1023 template_filename: paths.get_template_filename().clone(),
1024 generation_config: paths.get_gen_conf_filename().cloned(),
1025 generation_defaults,
1026 config,
1027 imatrix: self.config.imatrix.clone(),
1028 mapper: pipeline_mapper,
1029 })))
1030 }
1031
1032 fn get_id(&self) -> String {
1033 self.model_id.clone()
1034 }
1035
1036 fn get_kind(&self) -> ModelKind {
1037 self.kind.clone()
1038 }
1039}
1040
1041impl PreProcessingMixin for NormalPipeline {
1042 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
1043 Some(self.chat_template.clone())
1044 }
1045 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
1046 None
1047 }
1048}
1049
1050impl IsqPipelineMixin for NormalPipeline {
1051 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
1052 let device = self.device().clone();
1053 let multi_progress = Arc::new(new_multi_progress());
1054 self.model.quantize(
1055 Some(dtype),
1056 device.clone(),
1057 self.topology.as_ref(),
1058 self.silent,
1059 self.imatrix.as_ref().map(ImatrixDataSource::File),
1060 self.organization,
1061 true,
1062 None,
1063 UqffFullSer {
1064 tokenizer: &self.tokenizer,
1065 template_filename: &self.template_filename,
1066 generation_config: self.generation_config.as_ref(),
1067 config: self.config.clone(),
1068 processor_filename: &None,
1069 preprocessor_filename: &None,
1070 modules: None,
1071 module_paths: None,
1072 },
1073 multi_progress.clone(),
1074 )?;
1075 Ok(())
1076 }
1077}
1078
1079impl CacheManagerMixin for NormalPipeline {
1080 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
1081 match self.model.cache() {
1082 EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
1083 EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
1084 EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
1085 }
1086 }
1087 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1088 match self.model.cache() {
1089 EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
1090 EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
1091 EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
1092 }
1093 }
1094 fn set_none_cache(
1095 &self,
1096 seqs: &mut [&mut Sequence],
1097 reset_non_granular: bool,
1098 modify_draft_cache: bool,
1099 load_preallocated_cache: bool,
1100 ) {
1101 match self.model.cache() {
1102 EitherCache::Full(_) => {
1103 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
1104 }
1105 EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
1106 self,
1107 seqs,
1108 modify_draft_cache,
1109 load_preallocated_cache,
1110 ),
1111 EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
1112 self,
1113 seqs,
1114 modify_draft_cache,
1115 load_preallocated_cache,
1116 ),
1117 }
1118 if reset_non_granular {
1119 self.reset_non_granular_state()
1120 }
1121 }
1122 fn cache(&self) -> &EitherCache {
1123 self.model.cache()
1124 }
1125}
1126
1127impl MetadataMixin for NormalPipeline {
1128 fn device(&self) -> Device {
1129 self.model.device().clone()
1130 }
1131 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1132 Some(self.tokenizer.clone())
1133 }
1134 fn name(&self) -> String {
1135 self.model_id.clone()
1136 }
1137 fn reset_non_granular_state(&self) {
1138 if let Some(s) = self.non_granular_state.as_ref() {
1139 *self.cache().full().get_scalings_cache() = None;
1140 *get_mut_arcmutex!(s.non_granular_index) = 0;
1141 }
1142 }
1143 fn get_metadata(&self) -> Arc<GeneralMetadata> {
1144 self.metadata.clone()
1145 }
1146 fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
1147 self.generation_defaults.clone()
1148 }
1149 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1150 Some(&*self.mapper)
1151 }
1152}
1153
1154#[async_trait::async_trait]
1155impl Pipeline for NormalPipeline {
1156 fn forward_inputs(
1157 &mut self,
1158 inputs: Box<dyn Any>,
1159 return_raw_logits: bool,
1160 ) -> Result<ForwardInputsResult, candle_core::Error> {
1161 let ModelInputs {
1162 input_ids,
1163 input_ids_full,
1164 seqlen_offsets,
1165 seqlen_offsets_full,
1166 context_lens,
1167 position_ids,
1168 paged_attn_meta,
1169 flash_meta,
1170 flash_meta_full,
1171 } = *inputs.downcast().expect("Downcast failed.");
1172 let metadata = self.get_metadata();
1173 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1174 (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1175 (Some(_), None) => {
1176 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1178 }
1179 (None, Some(_)) => {
1180 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1182 }
1183 (None, None) => None,
1184 };
1185 let logits = match self.model.is_xlora() {
1186 false => {
1187 let paged_attn_meta = paged_attn_meta
1188 .as_ref()
1189 .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1190
1191 self.model.forward(
1192 &input_ids,
1193 &seqlen_offsets,
1194 context_lens,
1195 position_ids,
1196 paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1197 &flash_meta,
1198 )?
1199 }
1200 true => self.model.xlora_forward(
1201 &input_ids,
1202 input_ids_full.as_ref().unwrap_or(&input_ids),
1203 &seqlen_offsets,
1204 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1205 self.no_kv_cache,
1206 &self.non_granular_state,
1207 context_lens,
1208 position_ids,
1209 &flash_meta,
1210 flash_meta_full.as_ref().unwrap_or(&flash_meta),
1211 )?,
1212 };
1213 if return_raw_logits {
1214 Ok(ForwardInputsResult::RawLogits { logits })
1215 } else {
1216 Ok(ForwardInputsResult::CausalGeneration { logits })
1217 }
1218 }
1219 async fn sample_causal_gen(
1220 &self,
1221 seqs: &mut [&mut Sequence],
1222 logits: Vec<Tensor>,
1223 prefix_cacher: &mut PrefixCacheManagerV2,
1224 disable_eos_stop: bool,
1225 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1226 ) -> Result<(), candle_core::Error> {
1227 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1228 }
1229 fn category(&self) -> ModelCategory {
1230 ModelCategory::Text
1231 }
1232}
1233
1234impl AnyMoePipelineMixin for NormalPipeline {
1235 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1236 self.model.finish_training(gate_model_id)
1237 }
1238 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1239 self.model.get_vars()
1240 }
1241 fn amoe_base_model_trainable_params(&self) -> usize {
1242 self.model.trainable_params()
1243 }
1244 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1245 self.model.take_cached_gating_outputs()
1246 }
1247 fn amoe_create_layers(
1248 &mut self,
1249 model_ids: Vec<String>,
1250 token: &TokenSource,
1251 revision: Option<String>,
1252 match_regex: &str,
1253 config: crate::amoe::AnyMoeConfig,
1254 dtype: candle_core::DType,
1255 dev: &Device,
1256 (prefix, mlp): (String, String),
1257 layers: Vec<usize>,
1258 expert_type: AnyMoeExpertType,
1259 silent: bool,
1260 gate_model_id: Option<String>,
1261 ) -> candle_core::Result<()> {
1262 let mut vbs = Vec::new();
1263 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1265 for model_id in model_ids {
1266 let model_id_str = &model_id;
1267 let model_id = Path::new(&model_id);
1268
1269 let api = {
1270 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1271 let mut api = ApiBuilder::from_cache(cache)
1272 .with_progress(!silent)
1273 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1274 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1275 api = api.with_cache_dir(cache_dir);
1276 }
1277 api.build().map_err(candle_core::Error::msg)?
1278 };
1279 let revision = revision.clone().unwrap_or("main".to_string());
1280 let api = api.repo(Repo::with_revision(
1281 model_id_str.clone(),
1282 RepoType::Model,
1283 revision.clone(),
1284 ));
1285
1286 let mut filenames = vec![];
1287 for rfilename in
1288 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1289 {
1290 filenames.push(api_get_file!(api, &rfilename, model_id));
1291 }
1292
1293 let regex = regex.clone();
1294 let match_regex_clone = match_regex.to_string();
1295 let layers_clone = layers.clone();
1296 let vb = from_mmaped_safetensors(
1297 filenames,
1298 vec![],
1299 Some(dtype),
1300 dev,
1301 vec![None],
1302 silent,
1303 None,
1304 move |key| {
1305 if regex.is_match(&key) {
1306 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1309 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1310 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1311 .parse::<usize>()
1312 .unwrap();
1313 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1314 } else {
1315 false
1316 }
1317 },
1318 Arc::new(|_| DeviceForLoadTensor::Base),
1319 )?;
1320 vbs.push(vb);
1321 }
1322
1323 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1324 let model_id_str = &gate_model_id;
1325 let model_id = Path::new(&gate_model_id);
1326
1327 let api = {
1328 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1329 let mut api = ApiBuilder::from_cache(cache)
1330 .with_progress(!silent)
1331 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1332 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1333 api = api.with_cache_dir(cache_dir);
1334 }
1335 api.build().map_err(candle_core::Error::msg)?
1336 };
1337 let revision = revision.clone().unwrap_or("main".to_string());
1338 let api = api.repo(Repo::with_revision(
1339 model_id_str.clone(),
1340 RepoType::Model,
1341 revision.clone(),
1342 ));
1343
1344 let mut gate_filenames = vec![];
1345 for rfilename in
1346 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1347 {
1348 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1349 }
1350 assert_eq!(
1351 gate_filenames.len(),
1352 1,
1353 "Gate model ID must contain only one .safetensors file"
1354 );
1355
1356 let vb = from_mmaped_safetensors(
1357 gate_filenames.clone(),
1358 vec![],
1359 Some(dtype),
1360 dev,
1361 vec![None],
1362 silent,
1363 None,
1364 |_| true,
1365 Arc::new(|_| DeviceForLoadTensor::Base),
1366 )?;
1367 info!(
1368 "Loaded gating layers from `{}`",
1369 gate_filenames[0].display()
1370 );
1371 Some(vb)
1372 } else {
1373 None
1374 };
1375
1376 self.model.create_anymoe_layers(
1377 vbs.clone(),
1378 config.clone(),
1379 (prefix.clone(), mlp.clone()),
1380 layers.clone(),
1381 expert_type.clone(),
1382 gate_vb.clone(),
1383 )?;
1384
1385 Ok(())
1386 }
1387 fn amoe_supported(&self) -> bool {
1388 self.model.amoe_supported()
1389 }
1390}