1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6 TokenSource,
7};
8use super::{
9 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
14 GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
15 MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
16 Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::attention::ATTENTION_CHUNK_SIZE;
20use crate::device_map::{self, DeviceMapper};
21use crate::distributed::{self, WorkerTransferData};
22use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
23use crate::lora::Ordering;
24use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
25use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
26use crate::pipeline::isq::{UqffFullSer, WeightLoadingMode, WeightLoadingState};
27use crate::pipeline::loaders::auto_device_map;
28use crate::pipeline::loaders::QuantizationConfigShim;
29use crate::pipeline::sampling::sample_and_add_toks;
30use crate::pipeline::text_models_inputs_processor::{make_prompt_chunk, InputMetadata};
31use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
32use crate::pipeline::{ChatTemplate, LocalModelPaths};
33use crate::prefix_cacher::PrefixCacheManagerV2;
34use crate::sequence::Sequence;
35use crate::utils::tokenizer::get_tokenizer;
36use crate::utils::varbuilder_utils::DeviceForLoadTensor;
37use crate::utils::{
38 progress::{new_multi_progress, ProgressScopeGuard},
39 tokens::get_token,
40 varbuilder_utils::from_mmaped_safetensors,
41};
42use crate::xlora_models::NonGranularState;
43use crate::{
44 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
45 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
46 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
47};
48use anyhow::Result;
49use hanzo_ml::{Device, Tensor, Var};
50use hanzo_quant::log::once_log_info;
51use hanzo_quant::{
52 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
53};
54use hf_hub::Cache;
55use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
56use rand_isaac::Isaac64Rng;
57use regex_automata::meta::Regex;
58use std::any::Any;
59use std::borrow::Cow;
60use std::path::{Path, PathBuf};
61use std::str::FromStr;
62use std::sync::{Arc, RwLock};
63use std::time::Instant;
64use std::{env, fs};
65use tokenizers::Tokenizer;
66use tokio::sync::Mutex;
67use tracing::{debug, info, trace, warn};
68
69pub struct NormalPipeline {
70 model: Box<dyn NormalModel + Send + Sync>,
71 tokenizer: Arc<Tokenizer>,
72 no_kv_cache: bool,
73 chat_template: Arc<ChatTemplate>,
74 non_granular_state: Option<NonGranularState>,
75 model_id: String,
76 metadata: Arc<GeneralMetadata>,
77 topology: Option<Topology>,
78 silent: bool,
79 organization: IsqOrganization,
80 template_filename: Option<PathBuf>,
82 generation_config: Option<PathBuf>,
83 generation_defaults: Option<crate::ModelGenerationDefaults>,
84 config: String,
85 imatrix: Option<PathBuf>,
86 mapper: Box<dyn DeviceMapper + Send + Sync>,
87}
88
89pub struct NormalLoader {
91 inner: Box<dyn NormalModelLoader>,
92 model_id: String,
93 config: NormalSpecificConfig,
94 xlora_model_id: Option<String>,
95 lora_adapter_ids: Option<Vec<String>>,
96 kind: ModelKind,
97 xlora_order: Option<Ordering>,
98 no_kv_cache: bool,
99 chat_template: Option<String>,
100 tokenizer_json: Option<String>,
101 tgt_non_granular_index: Option<usize>,
102 token_source: RwLock<Option<TokenSource>>,
103 revision: RwLock<Option<String>>,
104 from_uqff: RwLock<Option<Vec<PathBuf>>>,
105 jinja_explicit: Option<String>,
106 hf_cache_path: Option<PathBuf>,
107}
108
109#[derive(Default)]
110pub struct NormalLoaderBuilder {
112 model_id: Option<String>,
113 config: NormalSpecificConfig,
114 xlora_model_id: Option<String>,
115 lora_adapter_ids: Option<Vec<String>>,
116 kind: ModelKind,
117 xlora_order: Option<Ordering>,
118 no_kv_cache: bool,
119 chat_template: Option<String>,
120 tokenizer_json: Option<String>,
121 tgt_non_granular_index: Option<usize>,
122 jinja_explicit: Option<String>,
123 hf_cache_path: Option<PathBuf>,
124}
125
126#[derive(Clone, Default)]
127pub struct NormalSpecificConfig {
129 pub topology: Option<Topology>,
130 pub organization: IsqOrganization,
131 pub write_uqff: Option<PathBuf>,
132 pub from_uqff: Option<Vec<PathBuf>>,
133 pub imatrix: Option<PathBuf>,
134 pub calibration_file: Option<PathBuf>,
135 pub hf_cache_path: Option<PathBuf>,
136 pub matformer_config_path: Option<PathBuf>,
137 pub matformer_slice_name: Option<String>,
138}
139
140impl NormalLoaderBuilder {
141 pub fn new(
142 config: NormalSpecificConfig,
143 chat_template: Option<String>,
144 tokenizer_json: Option<String>,
145 model_id: Option<String>,
146 no_kv_cache: bool,
147 jinja_explicit: Option<String>,
148 ) -> Self {
149 Self {
150 config,
151 chat_template,
152 tokenizer_json,
153 model_id,
154 kind: ModelKind::Normal,
155 jinja_explicit,
156 no_kv_cache,
157 ..Default::default()
158 }
159 }
160
161 fn with_adapter(
162 mut self,
163 xlora_model_id: String,
164 xlora_order: Ordering,
165 no_kv_cache: bool,
166 tgt_non_granular_index: Option<usize>,
167 ) -> Self {
168 self.xlora_model_id = Some(xlora_model_id);
169 self.xlora_order = Some(xlora_order);
170 self.no_kv_cache = no_kv_cache;
171 self.tgt_non_granular_index = tgt_non_granular_index;
172 self.model_id = if let Some(id) = self.model_id {
173 Some(id)
174 } else {
175 info!(
176 "Using adapter base model ID: `{}`",
177 self.xlora_order.as_ref().unwrap().base_model_id
178 );
179 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
180 };
181 self
182 }
183
184 pub fn with_xlora(
185 mut self,
186 xlora_model_id: String,
187 xlora_order: Ordering,
188 no_kv_cache: bool,
189 tgt_non_granular_index: Option<usize>,
190 ) -> Self {
191 self.kind = ModelKind::Adapter {
192 adapter: AdapterKind::XLora,
193 };
194 self.with_adapter(
195 xlora_model_id,
196 xlora_order,
197 no_kv_cache,
198 tgt_non_granular_index,
199 )
200 }
201
202 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
203 self.kind = ModelKind::Adapter {
204 adapter: AdapterKind::Lora,
205 };
206 self.lora_adapter_ids = Some(lora_adapter_ids);
207 self
208 }
209
210 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
211 self.hf_cache_path = Some(hf_cache_path);
212 self
213 }
214
215 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
218 let loader: Box<dyn NormalModelLoader> = match loader_tp {
219 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
220 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
221 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
222 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
223 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
224 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
225 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
226 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
227 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
228 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
229 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
230 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
231 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
232 Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
233 Some(NormalLoaderType::GLM4MoeLite) => Box::new(GLM4MoeLiteLoader),
234 Some(NormalLoaderType::GLM4Moe) => Box::new(GLM4MoeLoader),
235 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
236 Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
237 Some(NormalLoaderType::GraniteMoeHybrid) => Box::new(GraniteMoeHybridLoader),
238 Some(NormalLoaderType::GptOss) => Box::new(GptOssLoader),
239 Some(NormalLoaderType::Qwen3Next) => Box::new(Qwen3NextLoader),
240 None => Box::new(AutoNormalLoader),
241 };
242 Ok(Box::new(NormalLoader {
243 inner: loader,
244 model_id: self.model_id.unwrap(),
245 config: self.config,
246 xlora_model_id: self.xlora_model_id,
247 lora_adapter_ids: self.lora_adapter_ids,
248 kind: self.kind,
249 xlora_order: self.xlora_order,
250 no_kv_cache: self.no_kv_cache,
251 chat_template: self.chat_template,
252 tokenizer_json: self.tokenizer_json,
253 tgt_non_granular_index: self.tgt_non_granular_index,
254 jinja_explicit: self.jinja_explicit,
255 token_source: RwLock::new(None),
256 revision: RwLock::new(None),
257 from_uqff: RwLock::new(None),
258 hf_cache_path: self.hf_cache_path,
259 }))
260 }
261}
262
263impl Loader for NormalLoader {
264 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
265 fn load_model_from_hf(
266 &self,
267 revision: Option<String>,
268 token_source: TokenSource,
269 dtype: &dyn TryIntoDType,
270 device: &Device,
271 silent: bool,
272 mapper: DeviceMapSetting,
273 in_situ_quant: Option<IsqType>,
274 paged_attn_config: Option<PagedAttentionConfig>,
275 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
276 let _progress_guard = ProgressScopeGuard::new(silent);
277 let cache = self
278 .hf_cache_path
279 .clone()
280 .map(Cache::new)
281 .unwrap_or_default();
282 GLOBAL_HF_CACHE.get_or_init(|| cache);
283
284 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
285 LocalModelPaths,
286 &token_source,
287 revision.clone(),
288 self,
289 None,
290 None,
291 silent,
292 self.config.from_uqff.is_some()
293 );
294 *self
295 .token_source
296 .write()
297 .expect("Failed to write to token source") = Some(token_source);
298 *self.revision.write().expect("Failed to write to revision") = revision.clone();
299 if let Some(from_uqff) = self.config.from_uqff.clone() {
300 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
301 }
302 self.load_model_from_path(
303 &paths?,
304 dtype,
305 device,
306 silent,
307 mapper,
308 in_situ_quant,
309 paged_attn_config,
310 )
311 }
312
313 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
314 fn load_model_from_path(
315 &self,
316 paths: &Box<dyn ModelPaths>,
317 dtype: &dyn TryIntoDType,
318 device: &Device,
319 silent: bool,
320 mut mapper: DeviceMapSetting,
321 in_situ_quant: Option<IsqType>,
322 mut paged_attn_config: Option<PagedAttentionConfig>,
323 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
324 let _progress_guard = ProgressScopeGuard::new(silent);
325 let config = std::fs::read_to_string(paths.get_config_filename())?;
326
327 if !self.inner.supports_paged_attention(&config)? {
328 paged_attn_config = None;
329 }
330
331 debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
332
333 let use_nccl = hanzo_quant::distributed::use_nccl();
334
335 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
336 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
337 let WorkerTransferData::Init { id: _, worker_rank } = payload;
338 vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
339 } else if use_nccl {
340 vec![hanzo_ml::Device::new_cuda(0)?]
341 } else {
342 device_map::get_all_similar_devices(device)?
343 };
344 #[cfg(feature = "cuda")]
345 for device in &available_devices {
346 if let Device::Cuda(dev) = device {
347 unsafe { dev.disable_event_tracking() };
348 }
349 }
350 let device = if use_nccl || cfg!(feature = "ring") {
351 available_devices[0].clone()
352 } else {
353 device.clone()
354 };
355
356 let mut max_kv_tokens: Option<usize> = None;
358 if use_nccl || cfg!(feature = "ring") {
359 mapper = DeviceMapSetting::DummyNccl {
360 nm_device: available_devices[0].clone(),
361 };
362 } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
363 max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
364 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
366
367 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
370 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
371 let weight_pack_factor = {
372 let ser_artifacts =
373 unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(serialized)? };
374 let mut total_pack_factors = 0;
375 let total_tensors = ser_artifacts.tensors().len();
376 for (_, artifact) in ser_artifacts.tensors() {
377 let artifact = artifact.data();
378 let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
380 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
381 {
382 QuantizedSerdeType::Hqq => {
383 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
384 .pack_factor(dtype)
385 }
386 QuantizedSerdeType::Gguf => {
387 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
388 .pack_factor(dtype)
389 }
390 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
391 QuantizedSerdeType::Unquant => 1,
392 QuantizedSerdeType::Afq => {
393 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
394 .pack_factor(dtype)
395 }
396 QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
397 QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
398 };
399 total_pack_factors += pack_factor;
400 }
401
402 total_pack_factors / total_tensors
403 };
404
405 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
406 &config,
407 dtype,
408 weight_pack_factor,
409 None,
410 )?;
411 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
412 &config,
413 dtype,
414 weight_pack_factor,
415 None,
416 )?;
417 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
418 (
419 layer_sizes_in_bytes,
420 non_mapped_size_in_bytes,
421 layer_sizes_sum + non_mapped_size_in_bytes,
422 )
423 } else if let Some(isq) = in_situ_quant {
424 let weight_pack_factor = isq.pack_factor(dtype);
425 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
426 &config,
427 dtype,
428 weight_pack_factor,
429 None,
430 )?;
431 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
432 &config,
433 dtype,
434 weight_pack_factor,
435 None,
436 )?;
437 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
438 (
439 layer_sizes_in_bytes,
440 non_mapped_size_in_bytes,
441 layer_sizes_sum + non_mapped_size_in_bytes,
442 )
443 } else {
444 let weight_pack_factor =
446 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
447 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
448 &config,
449 dtype,
450 weight_pack_factor,
451 None,
452 )?;
453 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
454 &config,
455 dtype,
456 weight_pack_factor,
457 None,
458 )?;
459 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
460 (
461 layer_sizes_in_bytes,
462 non_mapped_size_in_bytes,
463 layer_sizes_sum + non_mapped_size_in_bytes,
464 )
465 };
466
467 let new = auto_device_map::get_device_layers(
468 &*self.inner,
469 &config,
470 self.inner.num_layers(&config)?,
471 layer_sizes_in_bytes,
472 non_mapped_size_in_bytes,
473 total_model_size_in_bytes,
474 &available_devices,
475 dtype,
476 ¶ms,
477 paged_attn_config.as_ref(),
478 )?;
479 mapper = DeviceMapSetting::Map(new);
480 }
481
482 let pipeline_mapper = mapper.into_mapper(
483 self.inner.num_layers(&config)?,
484 &device,
485 self.config.topology.as_ref(),
486 &available_devices,
487 )?;
488 let mapper = mapper.into_mapper(
489 self.inner.num_layers(&config)?,
490 &device,
491 self.config.topology.as_ref(),
492 &available_devices,
493 )?;
494 let mut layer_devices = Vec::new();
495 for layer in 0..self.inner.num_layers(&config)? {
496 let device = mapper.device_for(layer, false).cloned();
497 layer_devices.push(device);
498 }
499 let dtype = mapper.get_min_dtype(dtype)?;
500
501 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
504 if mapping_uses_cpu && paged_attn_config.is_some() {
505 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
506 paged_attn_config = None;
507 }
508
509 trace!("Model config: {:?}", self.inner.get_config_repr(&config)?);
510 if crate::using_flash_attn() {
511 once_log_info("FlashAttention is enabled.");
512 }
513
514 let topology_overrides = self
515 .config
516 .topology
517 .as_ref()
518 .map(|topology| {
519 topology
520 .pattern_overrides()
521 .into_iter()
522 .map(|(regex, layer)| ImmediateIsqOverride {
523 predicate: regex,
524 ty: layer.isq,
525 device: layer.device.clone(),
526 })
527 .collect::<Vec<_>>()
528 })
529 .unwrap_or_default();
530 let has_override_isq = topology_overrides
531 .iter()
532 .any(|override_entry| override_entry.ty.is_some());
533 let topology_requires_post_quant = self
534 .config
535 .topology
536 .as_ref()
537 .is_some_and(|topology| topology.requires_post_quantization());
538
539 let allow_immediate_cli = self.config.imatrix.is_none()
540 && self.config.calibration_file.is_none()
541 && in_situ_quant.is_some();
542
543 let mut immediate_ty = None;
544 let mut immediate_predicates = Vec::new();
545 if allow_immediate_cli {
546 immediate_ty = in_situ_quant;
547 immediate_predicates =
548 if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
549 self.inner.immediate_isq_predicates_moqe(&config)?
550 } else {
551 self.inner.immediate_isq_predicates(&config)?
552 };
553 info!("Applying ISQ to {in_situ_quant:?}");
554 if immediate_predicates.is_empty() {
555 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
556 }
557 }
558
559 let use_immediate = allow_immediate_cli || has_override_isq;
560 if use_immediate {
561 let (pool, num_threads) = hanzo_quant::create_isq_thread_pool(immediate_ty);
562 info!("Applying immediate ISQ in parallel on {num_threads} threads.");
563 hanzo_quant::set_immediate_isq_with_pool(
564 immediate_ty,
565 immediate_predicates.clone(),
566 topology_overrides.clone(),
567 pool,
568 );
569 }
570
571 let mut loading_isq = if use_immediate {
573 false
574 } else {
575 in_situ_quant.is_some()
576 };
577 if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
578 loading_isq = true;
579 }
580 loading_isq |= topology_requires_post_quant;
581 loading_isq |= self.config.from_uqff.is_some();
582
583 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
584 anyhow::bail!(
585 "`imatrix` and `calibration_file` were both specified, this is not allowed."
586 );
587 }
588
589 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
595 loading_isq = false;
596 if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
597 Device::Cpu
598 } else {
599 device.clone()
600 }
601 } else {
602 Device::Cpu
603 };
604
605 let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
606
607 let attention_mechanism = if paged_attn_config.is_some() {
608 AttentionImplementation::PagedAttention
609 } else {
610 AttentionImplementation::Eager
611 };
612
613 let multi_progress = Arc::new(new_multi_progress());
614
615 let matformer_slicing_config = if let Some(matformer_path) =
617 &self.config.matformer_config_path
618 {
619 use crate::matformer::{MatformerConfig, MatformerSliceConfig};
620 info!("Loading Matformer config from {:?}", matformer_path);
621 let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
622
623 if let Some(slice_name) = &self.config.matformer_slice_name {
624 info!("Using Matformer slice: {}", slice_name);
625 Some(MatformerSliceConfig::new(slice_name.clone(), config))
626 } else {
627 warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
630 None
631 }
632 } else {
633 None
634 };
635
636 info!(
637 "{}",
638 WeightLoadingMode::from(WeightLoadingState {
639 from_uqff: self.config.from_uqff.is_some(),
640 loading_isq,
641 immediate_isq: use_immediate,
642 write_uqff: self.config.write_uqff.is_some(),
643 })
644 .message("model")
645 );
646
647 let mut model = if use_nccl || cfg!(feature = "ring") {
648 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
649 dtype,
650 &device,
651 &available_devices,
652 silent,
653 &config,
654 loading_isq,
655 self.config.from_uqff.is_some(),
656 self.config.organization,
657 &*self.inner,
658 paths.as_ref(),
659 )?;
660
661 match self.kind {
663 ModelKind::Normal => normal_model_loader_sharded!(
664 sharded_vb,
665 config,
666 self.inner,
667 mapper,
668 loading_isq,
669 device.clone(),
670 attention_mechanism,
671 multi_progress.clone(),
672 matformer_slicing_config.clone(),
673 ),
674 ModelKind::Adapter {
675 adapter: AdapterKind::XLora,
676 } => xlora_model_loader!(
677 paths,
678 Some(dtype),
679 &load_device,
680 layer_devices.clone(),
681 config,
682 self.inner,
683 silent,
684 mapper,
685 loading_isq,
686 device.clone(),
687 multi_progress.clone(),
688 matformer_slicing_config.clone(),
689 ),
690 ModelKind::Adapter {
691 adapter: AdapterKind::Lora,
692 } => lora_model_loader!(
693 paths,
694 Some(dtype),
695 &load_device,
696 layer_devices.clone(),
697 config,
698 self.inner,
699 silent,
700 mapper,
701 loading_isq,
702 self.config.from_uqff.is_some(),
703 device.clone(),
704 attention_mechanism,
705 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
706 multi_progress.clone(),
707 matformer_slicing_config.clone(),
708 ),
709 _ => unreachable!(),
710 }
711 } else {
712 match self.kind {
713 ModelKind::Normal => normal_model_loader!(
714 paths,
715 Some(dtype),
716 &load_device,
717 layer_devices.clone(),
718 config,
719 self.inner,
720 silent,
721 mapper,
722 loading_isq,
723 self.config.from_uqff.is_some(),
724 device.clone(),
725 attention_mechanism,
726 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
727 multi_progress.clone(),
728 matformer_slicing_config.clone(),
729 ),
730 ModelKind::Adapter {
731 adapter: AdapterKind::XLora,
732 } => xlora_model_loader!(
733 paths,
734 Some(dtype),
735 &load_device,
736 layer_devices.clone(),
737 config,
738 self.inner,
739 silent,
740 mapper,
741 loading_isq,
742 device.clone(),
743 multi_progress.clone(),
744 matformer_slicing_config.clone(),
745 ),
746 ModelKind::Adapter {
747 adapter: AdapterKind::Lora,
748 } => lora_model_loader!(
749 paths,
750 Some(dtype),
751 &load_device,
752 layer_devices.clone(),
753 config,
754 self.inner,
755 silent,
756 mapper,
757 loading_isq,
758 self.config.from_uqff.is_some(),
759 device.clone(),
760 attention_mechanism,
761 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
762 multi_progress.clone(),
763 matformer_slicing_config.clone(),
764 ),
765 _ => unreachable!(),
766 }
767 };
768
769 let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
770 let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
771 match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
772 Ok(conf) => Some(conf),
773 Err(e) => {
774 warn!("Failed to parse generation_config.json: {}", e);
775 None
776 }
777 }
778 });
779
780 let chat_template_explicit = paths
781 .get_chat_template_explicit()
782 .as_ref()
783 .map(|x| x.to_string_lossy().to_string());
784 let chat_template = get_chat_template(
785 paths,
786 self.jinja_explicit.as_ref(),
787 chat_template_explicit.as_ref(),
788 self.chat_template.as_ref(),
789 None,
790 );
791
792 if let Some(calibration_file) = &self.config.calibration_file {
793 let calibration_data = std::fs::read_to_string(calibration_file)?;
794 let tokens = tokenizer
796 .encode_fast(calibration_data, false)
797 .map_err(anyhow::Error::msg)?
798 .get_ids()
799 .to_vec();
800 info!(
801 "Collecting imatrix from calibration file `{}` of {} tokens.",
802 calibration_file.display(),
803 tokens.len()
804 );
805 let bos_tok_id = chat_template
806 .bos_tok()
807 .as_deref()
808 .and_then(|tok| tokenizer.token_to_id(tok));
809
810 match self.config.organization {
811 IsqOrganization::Default => model.begin_track_stats()?,
812 IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
813 }
814
815 const CHUNK_SIZE: usize = 1024;
816 let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
817 let start = Instant::now();
818 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
819 let mut chunk = chunk.to_vec();
820 if let Some(bos_tok_id) = bos_tok_id {
821 chunk.insert(0, bos_tok_id);
822 }
823 let chunk_len = chunk.len();
824
825 let start = Instant::now();
826 let inputs = make_prompt_chunk(
827 0,
828 vec![&chunk],
829 &[0],
830 &load_device,
831 None,
832 false,
833 None,
834 Some(pipeline_mapper.as_ref()),
835 None,
836 model.config().sliding_window,
837 )?;
838
839 model.forward(
840 &inputs.input.to_device(model.device())?,
841 &inputs.positions,
842 inputs.context_lens.clone(),
843 inputs.position_ids.clone(),
844 None,
845 &inputs.flash_meta.clone(),
846 )?;
847
848 match model.cache_mut() {
849 EitherCache::Full(full) => {
850 for layer in &mut *full.lock() {
851 *layer = None
852 }
853 }
854 EitherCache::Normal(normal) => {
855 for layer in &mut *normal.lock().unwrap().0 {
856 layer.reset();
857 }
858 }
859 EitherCache::Hybrid(hybrid) => {
860 hybrid.lock().unwrap().reset();
861 }
862 }
863
864 let end = Instant::now();
865 info!(
866 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
867 i + 1,
868 end.duration_since(start).as_secs_f32()
869 );
870 }
871 load_device.synchronize()?;
872 let end = Instant::now();
873 info!(
874 "Finished collecting imatrix in {:.2}s",
875 end.duration_since(start).as_secs_f32()
876 );
877 }
878
879 let should_serialize = self.config.write_uqff.is_some();
881 let should_quantize_pass = loading_isq;
882
883 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
884 let imatrix_source = if should_quantize_pass {
885 match (
886 self.config.imatrix.as_ref(),
887 self.config.calibration_file.is_some(),
888 ) {
889 (None, false) => None,
890 (Some(file), false) => Some(ImatrixDataSource::File(file)),
891 (None, true) => Some(ImatrixDataSource::Collected),
892 (Some(_), true) => unreachable!(),
893 }
894 } else {
895 None
896 };
897
898 if should_quantize_pass {
899 debug!("Applying ISQ to all ranks.");
900 } else {
901 debug!("Serializing existing ISQ tensors without additional quantization.");
902 }
903
904 let multi_progress = Arc::new(new_multi_progress());
905
906 model.quantize(
907 in_situ_quant,
908 model.device().clone(),
909 self.config.topology.as_ref(),
910 silent,
911 imatrix_source,
912 self.config.organization,
913 should_quantize_pass,
914 self.config.write_uqff.as_ref(),
915 UqffFullSer {
916 tokenizer: &tokenizer,
917 template_filename: paths.get_template_filename(),
918 generation_config: paths.get_gen_conf_filename(),
919 config: config.clone(),
920 processor_filename: &None,
921 preprocessor_filename: &None,
922 modules: None,
923 module_paths: None,
924 },
925 multi_progress.clone(),
926 )?;
927 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
928 model.load_from_artifacts(
929 device.clone(),
930 self.config.topology.as_ref(),
931 silent,
932 from_uqff,
933 )?;
934 }
935
936 let paged_attn_config = if matches!(
937 self.kind,
938 ModelKind::Adapter {
939 adapter: AdapterKind::XLora
940 }
941 ) {
942 warn!(
943 "Adapter parallel_models do not currently support PagedAttention, running without"
944 );
945 None
946 } else {
947 paged_attn_config
948 };
949
950 let model_metadata = model.model_config();
951 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
952 let cache_config = calculate_cache_config(
953 paged_attn_config.mem_gpu,
954 paged_attn_config.block_size,
955 dtype,
956 paged_attn_config.cache_type,
957 model_metadata.as_ref(),
958 &device,
959 &pipeline_mapper
960 .get_unique_devices()
961 .into_iter()
962 .map(Some)
963 .collect::<Vec<_>>(),
964 silent,
965 None,
966 max_kv_tokens,
967 )?;
968
969 let mut layer_devices = Vec::new();
970 for layer in 0..self.inner.num_layers(&config)? {
971 let device = model.get_layers().1.device_for(layer, false).cloned();
972 layer_devices.push(device);
973 }
974 let cache_engine = CacheEngine::new(
975 model_metadata.as_ref(),
976 &cache_config,
977 dtype,
978 model.device(),
979 layer_devices.clone(),
980 )?;
981
982 (Some(cache_config), Some(cache_engine))
983 } else {
984 (None, None)
985 };
986
987 let max_seq_len = model.max_seq_len();
988 let llg_factory = build_llg_factory(tokenizer.clone())?;
989 let num_hidden_layers = match model.cache() {
990 EitherCache::Full(full) => full.lock().len(),
991 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
992 EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
993 };
994 let generation_defaults = gen_conf
995 .as_ref()
996 .and_then(GenerationConfig::generation_defaults);
997 let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
998 let sliding_window = model.config().sliding_window;
999 Ok(Arc::new(Mutex::new(NormalPipeline {
1000 model,
1001 tokenizer: tokenizer.into(),
1002 no_kv_cache: self.no_kv_cache,
1003 chat_template: Arc::new(chat_template),
1004 non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
1005 NonGranularState {
1006 non_granular_index: Arc::new(Mutex::new(0)),
1007 tgt_non_granular_index,
1008 }
1009 }),
1010 model_id: self.model_id.clone(),
1011 metadata: Arc::new(GeneralMetadata {
1012 max_seq_len,
1013 llg_factory: Some(llg_factory),
1014 no_kv_cache: self.no_kv_cache,
1015 no_prefix_cache: is_xlora,
1016 num_hidden_layers,
1017 eos_tok: eos,
1018 kind: self.kind.clone(),
1019 is_xlora,
1020 activation_dtype: dtype,
1021 sliding_window,
1022 cache_config,
1023 cache_engine,
1024 model_metadata: Some(model_metadata),
1025 modalities: Modalities {
1026 input: vec![SupportedModality::Text],
1027 output: vec![SupportedModality::Text],
1028 },
1029 }),
1030 topology: self.config.topology.clone(),
1031 silent,
1032 organization: self.config.organization,
1033 template_filename: paths.get_template_filename().clone(),
1034 generation_config: paths.get_gen_conf_filename().cloned(),
1035 generation_defaults,
1036 config,
1037 imatrix: self.config.imatrix.clone(),
1038 mapper: pipeline_mapper,
1039 })))
1040 }
1041
1042 fn get_id(&self) -> String {
1043 self.model_id.clone()
1044 }
1045
1046 fn get_kind(&self) -> ModelKind {
1047 self.kind.clone()
1048 }
1049}
1050
1051impl PreProcessingMixin for NormalPipeline {
1052 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
1053 Some(self.chat_template.clone())
1054 }
1055 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
1056 None
1057 }
1058}
1059
1060impl IsqPipelineMixin for NormalPipeline {
1061 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
1062 let device = self.device().clone();
1063 let multi_progress = Arc::new(new_multi_progress());
1064 self.model.quantize(
1065 Some(dtype),
1066 device.clone(),
1067 self.topology.as_ref(),
1068 self.silent,
1069 self.imatrix.as_ref().map(ImatrixDataSource::File),
1070 self.organization,
1071 true,
1072 None,
1073 UqffFullSer {
1074 tokenizer: &self.tokenizer,
1075 template_filename: &self.template_filename,
1076 generation_config: self.generation_config.as_ref(),
1077 config: self.config.clone(),
1078 processor_filename: &None,
1079 preprocessor_filename: &None,
1080 modules: None,
1081 module_paths: None,
1082 },
1083 multi_progress.clone(),
1084 )?;
1085 Ok(())
1086 }
1087}
1088
1089impl CacheManagerMixin for NormalPipeline {
1090 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
1091 match self.model.cache() {
1092 EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
1093 EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
1094 EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
1095 }
1096 }
1097 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1098 match self.model.cache() {
1099 EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
1100 EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
1101 EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
1102 }
1103 }
1104 fn set_none_cache(
1105 &self,
1106 seqs: &mut [&mut Sequence],
1107 reset_non_granular: bool,
1108 modify_draft_cache: bool,
1109 load_preallocated_cache: bool,
1110 ) {
1111 match self.model.cache() {
1112 EitherCache::Full(_) => {
1113 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
1114 }
1115 EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
1116 self,
1117 seqs,
1118 modify_draft_cache,
1119 load_preallocated_cache,
1120 ),
1121 EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
1122 self,
1123 seqs,
1124 modify_draft_cache,
1125 load_preallocated_cache,
1126 ),
1127 }
1128 if reset_non_granular {
1129 self.reset_non_granular_state()
1130 }
1131 }
1132 fn cache(&self) -> &EitherCache {
1133 self.model.cache()
1134 }
1135}
1136
1137impl MetadataMixin for NormalPipeline {
1138 fn device(&self) -> Device {
1139 self.model.device().clone()
1140 }
1141 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1142 Some(self.tokenizer.clone())
1143 }
1144 fn name(&self) -> String {
1145 self.model_id.clone()
1146 }
1147 fn reset_non_granular_state(&self) {
1148 if let Some(s) = self.non_granular_state.as_ref() {
1149 *self.cache().full().get_scalings_cache() = None;
1150 *get_mut_arcmutex!(s.non_granular_index) = 0;
1151 }
1152 }
1153 fn get_metadata(&self) -> Arc<GeneralMetadata> {
1154 self.metadata.clone()
1155 }
1156 fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
1157 self.generation_defaults.clone()
1158 }
1159 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1160 Some(&*self.mapper)
1161 }
1162}
1163
1164impl crate::speculative::driver::SpeculativePipelineExt for NormalPipeline {
1165 fn has_speculative_proposer(&self) -> bool {
1166 self.model.has_speculative_proposer()
1167 }
1168
1169 fn speculative_proposal_len(&self) -> Option<usize> {
1170 self.model.speculative_proposal_len()
1171 }
1172
1173 fn speculative_target_hiddens(
1174 &self,
1175 rows: &[(usize, usize)],
1176 ) -> hanzo_ml::Result<Option<Tensor>> {
1177 self.model.speculative_target_hiddens(rows)
1178 }
1179
1180 fn speculative_propose(
1181 &mut self,
1182 ctx: crate::speculative::SpeculativeProposeBatchCtx<'_>,
1183 ) -> hanzo_ml::Result<Option<crate::speculative::SpeculativeProposalBatch>> {
1184 self.model.speculative_propose(ctx)
1185 }
1186
1187 fn build_speculative_verify_inputs(
1188 &self,
1189 input_meta: InputMetadata,
1190 ) -> hanzo_ml::Result<Box<dyn Any>> {
1191 Ok(Box::new(ModelInputs {
1192 input_ids: input_meta.input,
1193 input_ids_full: None,
1194 seqlen_offsets: input_meta.positions,
1195 seqlen_offsets_full: None,
1196 context_lens: input_meta.context_lens,
1197 position_ids: input_meta.position_ids,
1198 paged_attn_meta: input_meta.paged_attn_meta,
1199 flash_meta: input_meta.flash_meta,
1200 flash_meta_full: None,
1201 }))
1202 }
1203}
1204
1205#[async_trait::async_trait]
1206impl Pipeline for NormalPipeline {
1207 fn forward_inputs(
1208 &mut self,
1209 inputs: Box<dyn Any>,
1210 return_raw_logits: bool,
1211 ) -> Result<ForwardInputsResult, hanzo_ml::Error> {
1212 let ModelInputs {
1213 input_ids,
1214 input_ids_full,
1215 seqlen_offsets,
1216 seqlen_offsets_full,
1217 context_lens,
1218 position_ids,
1219 paged_attn_meta,
1220 flash_meta,
1221 flash_meta_full,
1222 } = *inputs.downcast().expect("Downcast failed.");
1223 let metadata = self.get_metadata();
1224 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1225 (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1226 (Some(_), None) => {
1227 hanzo_ml::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1229 }
1230 (None, Some(_)) => {
1231 hanzo_ml::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1233 }
1234 (None, None) => None,
1235 };
1236 let logits = match self.model.is_xlora() {
1237 false => {
1238 let paged_attn_meta = paged_attn_meta
1239 .as_ref()
1240 .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1241
1242 self.model.forward(
1243 &input_ids,
1244 &seqlen_offsets,
1245 context_lens,
1246 position_ids,
1247 paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1248 &flash_meta,
1249 )?
1250 }
1251 true => self.model.xlora_forward(
1252 &input_ids,
1253 input_ids_full.as_ref().unwrap_or(&input_ids),
1254 &seqlen_offsets,
1255 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1256 self.no_kv_cache,
1257 &self.non_granular_state,
1258 context_lens,
1259 position_ids,
1260 &flash_meta,
1261 flash_meta_full.as_ref().unwrap_or(&flash_meta),
1262 )?,
1263 };
1264 if return_raw_logits {
1265 Ok(ForwardInputsResult::RawLogits { logits })
1266 } else {
1267 Ok(ForwardInputsResult::CausalGeneration { logits })
1268 }
1269 }
1270 fn attach_speculative(
1271 &mut self,
1272 config: crate::speculative::SpeculativeConfig,
1273 ) -> hanzo_ml::Result<()> {
1274 if matches!(config, crate::speculative::SpeculativeConfig::Mtp(_))
1275 && self.get_metadata().cache_engine.is_none()
1276 {
1277 hanzo_ml::bail!(
1278 "MTP speculative decoding currently requires PagedAttention for this pipeline."
1279 );
1280 }
1281 if let Some(info) = self.model.attach_speculative(config)? {
1282 self.model.log_speculative_attach(&info);
1283 }
1284 Ok(())
1285 }
1286
1287 #[allow(clippy::too_many_arguments)]
1288 async fn try_sample_speculative_causal_gen(
1289 &mut self,
1290 seqs: &mut [&mut Sequence],
1291 logits: &[Tensor],
1292 prefix_cacher: &mut PrefixCacheManagerV2,
1293 disable_eos_stop: bool,
1294 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1295 metadata: Option<crate::pipeline::text_models_inputs_processor::PagedAttentionMeta>,
1296 ) -> hanzo_ml::Result<bool> {
1297 if !self.model.has_speculative_proposer() {
1298 crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1299 return Ok(false);
1300 }
1301
1302 let general_metadata = self.get_metadata();
1303 if let Some(cache_engine) = general_metadata.cache_engine.as_ref() {
1304 let Some(metadata) = metadata else {
1305 crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1306 return Ok(false);
1307 };
1308 let cache = crate::speculative::cache::PagedSpeculativeCacheAccess::new(
1309 &metadata,
1310 cache_engine,
1311 );
1312 return crate::speculative::driver::try_sample_speculative_causal_gen(
1313 self,
1314 seqs,
1315 logits,
1316 prefix_cacher,
1317 disable_eos_stop,
1318 rng,
1319 &cache,
1320 )
1321 .await;
1322 }
1323
1324 crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1325 Ok(false)
1326 }
1327
1328 async fn sample_causal_gen(
1329 &self,
1330 seqs: &mut [&mut Sequence],
1331 logits: Vec<Tensor>,
1332 prefix_cacher: &mut PrefixCacheManagerV2,
1333 disable_eos_stop: bool,
1334 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1335 ) -> Result<(), hanzo_ml::Error> {
1336 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1337 }
1338 fn category(&self) -> ModelCategory {
1339 ModelCategory::Text
1340 }
1341}
1342
1343impl AnyMoePipelineMixin for NormalPipeline {
1344 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> hanzo_ml::Result<()> {
1345 self.model.finish_training(gate_model_id)
1346 }
1347 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1348 self.model.get_vars()
1349 }
1350 fn amoe_base_model_trainable_params(&self) -> usize {
1351 self.model.trainable_params()
1352 }
1353 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1354 self.model.take_cached_gating_outputs()
1355 }
1356 fn amoe_create_layers(
1357 &mut self,
1358 model_ids: Vec<String>,
1359 token: &TokenSource,
1360 revision: Option<String>,
1361 match_regex: &str,
1362 config: crate::amoe::AnyMoeConfig,
1363 dtype: hanzo_ml::DType,
1364 dev: &Device,
1365 (prefix, mlp): (String, String),
1366 layers: Vec<usize>,
1367 expert_type: AnyMoeExpertType,
1368 silent: bool,
1369 gate_model_id: Option<String>,
1370 ) -> hanzo_ml::Result<()> {
1371 let mut vbs = Vec::new();
1372 let regex = Regex::new(match_regex).map_err(hanzo_ml::Error::msg)?;
1374 for model_id in model_ids {
1375 let model_id_str = &model_id;
1376 let model_id = Path::new(&model_id);
1377
1378 let api = {
1379 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1380 let mut api = ApiBuilder::from_cache(cache)
1381 .with_progress(!silent)
1382 .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1383 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1384 api = api.with_cache_dir(cache_dir);
1385 }
1386 api.build().map_err(hanzo_ml::Error::msg)?
1387 };
1388 let revision = revision.clone().unwrap_or("main".to_string());
1389 let api = api.repo(Repo::with_revision(
1390 model_id_str.clone(),
1391 RepoType::Model,
1392 revision.clone(),
1393 ));
1394
1395 let mut filenames = vec![];
1396 for rfilename in api_dir_list!(api, model_id, true, &revision)
1397 .filter(|x| x.ends_with(".safetensors"))
1398 {
1399 filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1400 }
1401
1402 let regex = regex.clone();
1403 let match_regex_clone = match_regex.to_string();
1404 let layers_clone = layers.clone();
1405 let vb = from_mmaped_safetensors(
1406 filenames,
1407 vec![],
1408 Some(dtype),
1409 dev,
1410 vec![None],
1411 silent,
1412 None,
1413 move |key| {
1414 if regex.is_match(&key) {
1415 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1418 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1419 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1420 .parse::<usize>()
1421 .unwrap();
1422 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1423 } else {
1424 false
1425 }
1426 },
1427 Arc::new(|_| DeviceForLoadTensor::Base),
1428 )?;
1429 vbs.push(vb);
1430 }
1431
1432 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1433 let model_id_str = &gate_model_id;
1434 let model_id = Path::new(&gate_model_id);
1435
1436 let api = {
1437 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1438 let mut api = ApiBuilder::from_cache(cache)
1439 .with_progress(!silent)
1440 .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1441 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1442 api = api.with_cache_dir(cache_dir);
1443 }
1444 api.build().map_err(hanzo_ml::Error::msg)?
1445 };
1446 let revision = revision.clone().unwrap_or("main".to_string());
1447 let api = api.repo(Repo::with_revision(
1448 model_id_str.clone(),
1449 RepoType::Model,
1450 revision.clone(),
1451 ));
1452
1453 let mut gate_filenames = vec![];
1454 for rfilename in api_dir_list!(api, model_id, true, &revision)
1455 .filter(|x| x.ends_with(".safetensors"))
1456 {
1457 gate_filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1458 }
1459 assert_eq!(
1460 gate_filenames.len(),
1461 1,
1462 "Gate model ID must contain only one .safetensors file"
1463 );
1464
1465 let vb = from_mmaped_safetensors(
1466 gate_filenames.clone(),
1467 vec![],
1468 Some(dtype),
1469 dev,
1470 vec![None],
1471 silent,
1472 None,
1473 |_| true,
1474 Arc::new(|_| DeviceForLoadTensor::Base),
1475 )?;
1476 info!(
1477 "Loaded gating layers from `{}`",
1478 gate_filenames[0].display()
1479 );
1480 Some(vb)
1481 } else {
1482 None
1483 };
1484
1485 self.model.create_anymoe_layers(
1486 vbs.clone(),
1487 config.clone(),
1488 (prefix.clone(), mlp.clone()),
1489 layers.clone(),
1490 expert_type.clone(),
1491 gate_vb.clone(),
1492 )?;
1493
1494 Ok(())
1495 }
1496 fn amoe_supported(&self) -> bool {
1497 self.model.amoe_supported()
1498 }
1499}