Skip to main content

hanzo_engine/pipeline/
normal.rs

1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6    TokenSource,
7};
8use super::{
9    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
14    GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
15    MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
16    Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::attention::ATTENTION_CHUNK_SIZE;
20use crate::device_map::{self, DeviceMapper};
21use crate::distributed::{self, WorkerTransferData};
22use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
23use crate::lora::Ordering;
24use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
25use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
26use crate::pipeline::isq::{UqffFullSer, WeightLoadingMode, WeightLoadingState};
27use crate::pipeline::loaders::auto_device_map;
28use crate::pipeline::loaders::QuantizationConfigShim;
29use crate::pipeline::sampling::sample_and_add_toks;
30use crate::pipeline::text_models_inputs_processor::{make_prompt_chunk, InputMetadata};
31use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
32use crate::pipeline::{ChatTemplate, LocalModelPaths};
33use crate::prefix_cacher::PrefixCacheManagerV2;
34use crate::sequence::Sequence;
35use crate::utils::tokenizer::get_tokenizer;
36use crate::utils::varbuilder_utils::DeviceForLoadTensor;
37use crate::utils::{
38    progress::{new_multi_progress, ProgressScopeGuard},
39    tokens::get_token,
40    varbuilder_utils::from_mmaped_safetensors,
41};
42use crate::xlora_models::NonGranularState;
43use crate::{
44    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
45    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
46    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
47};
48use anyhow::Result;
49use hanzo_ml::{Device, Tensor, Var};
50use hanzo_quant::log::once_log_info;
51use hanzo_quant::{
52    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
53};
54use hf_hub::Cache;
55use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
56use rand_isaac::Isaac64Rng;
57use regex_automata::meta::Regex;
58use std::any::Any;
59use std::borrow::Cow;
60use std::path::{Path, PathBuf};
61use std::str::FromStr;
62use std::sync::{Arc, RwLock};
63use std::time::Instant;
64use std::{env, fs};
65use tokenizers::Tokenizer;
66use tokio::sync::Mutex;
67use tracing::{debug, info, trace, warn};
68
/// Runtime pipeline for a "normal" model: the loaded model together with its tokenizer,
/// chat template, general metadata, and the device mapper used to place its layers.
pub struct NormalPipeline {
    /// The loaded model implementation.
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    /// Whether the KV cache is disabled for this pipeline.
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    /// NOTE(review): presumably only populated for X-LoRA models (non-granular scaling) — confirm.
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    // For full UQFF serialization: the source files/config below are kept so they can be
    // emitted alongside the quantized weights (NOTE(review): confirm exact set written).
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    generation_defaults: Option<crate::ModelGenerationDefaults>,
    /// Raw model config (JSON text, as read from the config file).
    config: String,
    imatrix: Option<PathBuf>,
    /// Maps layer indices to devices.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
88
/// A loader for a "normal" model, optionally with X-LoRA or LoRA adapters.
///
/// Despite the name, the weights need not stay unquantized: the loader supports in-situ
/// quantization (ISQ), loading from UQFF artifacts, and models whose config declares a
/// pre-quantization scheme.
pub struct NormalLoader {
    /// Architecture-specific loader, chosen by `NormalLoaderBuilder::build`
    /// (`AutoNormalLoader` when no loader type was given).
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    /// `ModelKind::Normal` or `ModelKind::Adapter { .. }` (X-LoRA / LoRA).
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    /// Populated by `load_model_from_hf`; interior mutability because loading takes `&self`.
    token_source: RwLock<Option<TokenSource>>,
    /// Populated by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// Resolved local UQFF artifact paths, populated by `load_model_from_hf` when
    /// `config.from_uqff` is set; read by automatic device mapping to size layers.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Explicit Jinja chat-template override, forwarded to `get_chat_template`.
    jinja_explicit: Option<String>,
    /// Used to initialize the process-wide HF cache (`GLOBAL_HF_CACHE`) if still unset.
    hf_cache_path: Option<PathBuf>,
}
108
/// Builder for a [`NormalLoader`].
///
/// Start with `NormalLoaderBuilder::new`, optionally attach an adapter (`with_xlora` /
/// `with_lora`) or an HF cache path, then call `build`.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    /// May be left `None` when an X-LoRA adapter supplies a base model ID.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
125
/// Loading options specific to "normal" models (topology, ISQ/UQFF, calibration, Matformer).
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Optional per-layer topology: passed to the device mapper and used as a source of
    /// per-pattern ISQ type / device overrides.
    pub topology: Option<Topology>,
    /// `IsqOrganization::MoeExpertsOnly` selects the `immediate_isq_predicates_moqe`
    /// predicate set instead of `immediate_isq_predicates`.
    pub organization: IsqOrganization,
    /// Target for UQFF serialization. Only its presence is consulted in the loading
    /// code visible here (NOTE(review): confirm when/where the file is written).
    pub write_uqff: Option<PathBuf>,
    /// UQFF artifacts to load weights from; resolved to local paths in `load_model_from_hf`.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Importance-matrix file for ISQ. When set, immediate ISQ is not used.
    /// Mutually exclusive with `calibration_file`.
    pub imatrix: Option<PathBuf>,
    /// Calibration text file, read as a string at load time. When set, immediate ISQ is
    /// not used. Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// NOTE(review): `NormalLoader` initializes the HF cache from its own `hf_cache_path`
    /// (set via `NormalLoaderBuilder::hf_cache_path`); this field is not read in the
    /// visible loading code — confirm which one is authoritative.
    pub hf_cache_path: Option<PathBuf>,
    /// Matformer config file, loaded via `MatformerConfig::from_file` when set.
    pub matformer_config_path: Option<PathBuf>,
    /// Matformer slice to use; if the config is set but this is `None`, a warning is
    /// logged and no slicing config is passed to the model.
    pub matformer_slice_name: Option<String>,
}
139
140impl NormalLoaderBuilder {
141    pub fn new(
142        config: NormalSpecificConfig,
143        chat_template: Option<String>,
144        tokenizer_json: Option<String>,
145        model_id: Option<String>,
146        no_kv_cache: bool,
147        jinja_explicit: Option<String>,
148    ) -> Self {
149        Self {
150            config,
151            chat_template,
152            tokenizer_json,
153            model_id,
154            kind: ModelKind::Normal,
155            jinja_explicit,
156            no_kv_cache,
157            ..Default::default()
158        }
159    }
160
161    fn with_adapter(
162        mut self,
163        xlora_model_id: String,
164        xlora_order: Ordering,
165        no_kv_cache: bool,
166        tgt_non_granular_index: Option<usize>,
167    ) -> Self {
168        self.xlora_model_id = Some(xlora_model_id);
169        self.xlora_order = Some(xlora_order);
170        self.no_kv_cache = no_kv_cache;
171        self.tgt_non_granular_index = tgt_non_granular_index;
172        self.model_id = if let Some(id) = self.model_id {
173            Some(id)
174        } else {
175            info!(
176                "Using adapter base model ID: `{}`",
177                self.xlora_order.as_ref().unwrap().base_model_id
178            );
179            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
180        };
181        self
182    }
183
184    pub fn with_xlora(
185        mut self,
186        xlora_model_id: String,
187        xlora_order: Ordering,
188        no_kv_cache: bool,
189        tgt_non_granular_index: Option<usize>,
190    ) -> Self {
191        self.kind = ModelKind::Adapter {
192            adapter: AdapterKind::XLora,
193        };
194        self.with_adapter(
195            xlora_model_id,
196            xlora_order,
197            no_kv_cache,
198            tgt_non_granular_index,
199        )
200    }
201
202    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
203        self.kind = ModelKind::Adapter {
204            adapter: AdapterKind::Lora,
205        };
206        self.lora_adapter_ids = Some(lora_adapter_ids);
207        self
208    }
209
210    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
211        self.hf_cache_path = Some(hf_cache_path);
212        self
213    }
214
215    /// If the loader type is not specified, loader type is automatically determined from the
216    /// `architectures` array in the config.
217    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
218        let loader: Box<dyn NormalModelLoader> = match loader_tp {
219            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
220            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
221            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
222            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
223            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
224            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
225            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
226            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
227            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
228            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
229            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
230            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
231            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
232            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
233            Some(NormalLoaderType::GLM4MoeLite) => Box::new(GLM4MoeLiteLoader),
234            Some(NormalLoaderType::GLM4Moe) => Box::new(GLM4MoeLoader),
235            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
236            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
237            Some(NormalLoaderType::GraniteMoeHybrid) => Box::new(GraniteMoeHybridLoader),
238            Some(NormalLoaderType::GptOss) => Box::new(GptOssLoader),
239            Some(NormalLoaderType::Qwen3Next) => Box::new(Qwen3NextLoader),
240            None => Box::new(AutoNormalLoader),
241        };
242        Ok(Box::new(NormalLoader {
243            inner: loader,
244            model_id: self.model_id.unwrap(),
245            config: self.config,
246            xlora_model_id: self.xlora_model_id,
247            lora_adapter_ids: self.lora_adapter_ids,
248            kind: self.kind,
249            xlora_order: self.xlora_order,
250            no_kv_cache: self.no_kv_cache,
251            chat_template: self.chat_template,
252            tokenizer_json: self.tokenizer_json,
253            tgt_non_granular_index: self.tgt_non_granular_index,
254            jinja_explicit: self.jinja_explicit,
255            token_source: RwLock::new(None),
256            revision: RwLock::new(None),
257            from_uqff: RwLock::new(None),
258            hf_cache_path: self.hf_cache_path,
259        }))
260    }
261}
262
263impl Loader for NormalLoader {
264    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
265    fn load_model_from_hf(
266        &self,
267        revision: Option<String>,
268        token_source: TokenSource,
269        dtype: &dyn TryIntoDType,
270        device: &Device,
271        silent: bool,
272        mapper: DeviceMapSetting,
273        in_situ_quant: Option<IsqType>,
274        paged_attn_config: Option<PagedAttentionConfig>,
275    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
276        let _progress_guard = ProgressScopeGuard::new(silent);
277        let cache = self
278            .hf_cache_path
279            .clone()
280            .map(Cache::new)
281            .unwrap_or_default();
282        GLOBAL_HF_CACHE.get_or_init(|| cache);
283
284        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
285            LocalModelPaths,
286            &token_source,
287            revision.clone(),
288            self,
289            None,
290            None,
291            silent,
292            self.config.from_uqff.is_some()
293        );
294        *self
295            .token_source
296            .write()
297            .expect("Failed to write to token source") = Some(token_source);
298        *self.revision.write().expect("Failed to write to revision") = revision.clone();
299        if let Some(from_uqff) = self.config.from_uqff.clone() {
300            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
301        }
302        self.load_model_from_path(
303            &paths?,
304            dtype,
305            device,
306            silent,
307            mapper,
308            in_situ_quant,
309            paged_attn_config,
310        )
311    }
312
313    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
314    fn load_model_from_path(
315        &self,
316        paths: &Box<dyn ModelPaths>,
317        dtype: &dyn TryIntoDType,
318        device: &Device,
319        silent: bool,
320        mut mapper: DeviceMapSetting,
321        in_situ_quant: Option<IsqType>,
322        mut paged_attn_config: Option<PagedAttentionConfig>,
323    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
324        let _progress_guard = ProgressScopeGuard::new(silent);
325        let config = std::fs::read_to_string(paths.get_config_filename())?;
326
327        if !self.inner.supports_paged_attention(&config)? {
328            paged_attn_config = None;
329        }
330
331        debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
332
333        let use_nccl = hanzo_quant::distributed::use_nccl();
334
335        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
336            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
337            let WorkerTransferData::Init { id: _, worker_rank } = payload;
338            vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
339        } else if use_nccl {
340            vec![hanzo_ml::Device::new_cuda(0)?]
341        } else {
342            device_map::get_all_similar_devices(device)?
343        };
344        #[cfg(feature = "cuda")]
345        for device in &available_devices {
346            if let Device::Cuda(dev) = device {
347                unsafe { dev.disable_event_tracking() };
348            }
349        }
350        let device = if use_nccl || cfg!(feature = "ring") {
351            available_devices[0].clone()
352        } else {
353            device.clone()
354        };
355
356        // If auto, convert to Map if not using nccl
357        let mut max_kv_tokens: Option<usize> = None;
358        if use_nccl || cfg!(feature = "ring") {
359            mapper = DeviceMapSetting::DummyNccl {
360                nm_device: available_devices[0].clone(),
361            };
362        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
363            max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
364            // Initial dtype
365            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
366
367            // ISQ or UQFF: quantized path
368            // Match logic below where UQFF has priority
369            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
370                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
371                    let weight_pack_factor = {
372                        let ser_artifacts =
373                            unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(serialized)? };
374                        let mut total_pack_factors = 0;
375                        let total_tensors = ser_artifacts.tensors().len();
376                        for (_, artifact) in ser_artifacts.tensors() {
377                            let artifact = artifact.data();
378                            // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
379                            let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
380                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
381                            {
382                                QuantizedSerdeType::Hqq => {
383                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
384                                        .pack_factor(dtype)
385                                }
386                                QuantizedSerdeType::Gguf => {
387                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
388                                        .pack_factor(dtype)
389                                }
390                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
391                                QuantizedSerdeType::Unquant => 1,
392                                QuantizedSerdeType::Afq => {
393                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
394                                        .pack_factor(dtype)
395                                }
396                                QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
397                                QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
398                            };
399                            total_pack_factors += pack_factor;
400                        }
401
402                        total_pack_factors / total_tensors
403                    };
404
405                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
406                        &config,
407                        dtype,
408                        weight_pack_factor,
409                        None,
410                    )?;
411                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
412                        &config,
413                        dtype,
414                        weight_pack_factor,
415                        None,
416                    )?;
417                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
418                    (
419                        layer_sizes_in_bytes,
420                        non_mapped_size_in_bytes,
421                        layer_sizes_sum + non_mapped_size_in_bytes,
422                    )
423                } else if let Some(isq) = in_situ_quant {
424                    let weight_pack_factor = isq.pack_factor(dtype);
425                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
426                        &config,
427                        dtype,
428                        weight_pack_factor,
429                        None,
430                    )?;
431                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
432                        &config,
433                        dtype,
434                        weight_pack_factor,
435                        None,
436                    )?;
437                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
438                    (
439                        layer_sizes_in_bytes,
440                        non_mapped_size_in_bytes,
441                        layer_sizes_sum + non_mapped_size_in_bytes,
442                    )
443                } else {
444                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
445                    let weight_pack_factor =
446                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
447                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
448                        &config,
449                        dtype,
450                        weight_pack_factor,
451                        None,
452                    )?;
453                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
454                        &config,
455                        dtype,
456                        weight_pack_factor,
457                        None,
458                    )?;
459                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
460                    (
461                        layer_sizes_in_bytes,
462                        non_mapped_size_in_bytes,
463                        layer_sizes_sum + non_mapped_size_in_bytes,
464                    )
465                };
466
467            let new = auto_device_map::get_device_layers(
468                &*self.inner,
469                &config,
470                self.inner.num_layers(&config)?,
471                layer_sizes_in_bytes,
472                non_mapped_size_in_bytes,
473                total_model_size_in_bytes,
474                &available_devices,
475                dtype,
476                &params,
477                paged_attn_config.as_ref(),
478            )?;
479            mapper = DeviceMapSetting::Map(new);
480        }
481
482        let pipeline_mapper = mapper.into_mapper(
483            self.inner.num_layers(&config)?,
484            &device,
485            self.config.topology.as_ref(),
486            &available_devices,
487        )?;
488        let mapper = mapper.into_mapper(
489            self.inner.num_layers(&config)?,
490            &device,
491            self.config.topology.as_ref(),
492            &available_devices,
493        )?;
494        let mut layer_devices = Vec::new();
495        for layer in 0..self.inner.num_layers(&config)? {
496            let device = mapper.device_for(layer, false).cloned();
497            layer_devices.push(device);
498        }
499        let dtype = mapper.get_min_dtype(dtype)?;
500
501        // TODO: PagedAttention is not supported with CPU for now.
502        // This check is not really necessary because `get_device_layers` should prevent it.
503        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
504        if mapping_uses_cpu && paged_attn_config.is_some() {
505            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
506            paged_attn_config = None;
507        }
508
509        trace!("Model config: {:?}", self.inner.get_config_repr(&config)?);
510        if crate::using_flash_attn() {
511            once_log_info("FlashAttention is enabled.");
512        }
513
514        let topology_overrides = self
515            .config
516            .topology
517            .as_ref()
518            .map(|topology| {
519                topology
520                    .pattern_overrides()
521                    .into_iter()
522                    .map(|(regex, layer)| ImmediateIsqOverride {
523                        predicate: regex,
524                        ty: layer.isq,
525                        device: layer.device.clone(),
526                    })
527                    .collect::<Vec<_>>()
528            })
529            .unwrap_or_default();
530        let has_override_isq = topology_overrides
531            .iter()
532            .any(|override_entry| override_entry.ty.is_some());
533        let topology_requires_post_quant = self
534            .config
535            .topology
536            .as_ref()
537            .is_some_and(|topology| topology.requires_post_quantization());
538
539        let allow_immediate_cli = self.config.imatrix.is_none()
540            && self.config.calibration_file.is_none()
541            && in_situ_quant.is_some();
542
543        let mut immediate_ty = None;
544        let mut immediate_predicates = Vec::new();
545        if allow_immediate_cli {
546            immediate_ty = in_situ_quant;
547            immediate_predicates =
548                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
549                    self.inner.immediate_isq_predicates_moqe(&config)?
550                } else {
551                    self.inner.immediate_isq_predicates(&config)?
552                };
553            info!("Applying ISQ to {in_situ_quant:?}");
554            if immediate_predicates.is_empty() {
555                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
556            }
557        }
558
559        let use_immediate = allow_immediate_cli || has_override_isq;
560        if use_immediate {
561            let (pool, num_threads) = hanzo_quant::create_isq_thread_pool(immediate_ty);
562            info!("Applying immediate ISQ in parallel on {num_threads} threads.");
563            hanzo_quant::set_immediate_isq_with_pool(
564                immediate_ty,
565                immediate_predicates.clone(),
566                topology_overrides.clone(),
567                pool,
568            );
569        }
570
571        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
572        let mut loading_isq = if use_immediate {
573            false
574        } else {
575            in_situ_quant.is_some()
576        };
577        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
578            loading_isq = true;
579        }
580        loading_isq |= topology_requires_post_quant;
581        loading_isq |= self.config.from_uqff.is_some();
582
583        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
584            anyhow::bail!(
585                "`imatrix` and `calibration_file` were both specified, this is not allowed."
586            );
587        }
588
589        // Load onto the regular device if not using isq or if the calibration file is specified.
590        // For immediate ISQ on discrete GPUs, load to CPU: the mapper will set the correct target
591        // device per-layer, and linear constructors will override to CPU for ISQ-targeted weights.
592        // On integrated/unified memory systems (e.g. Grace Blackwell), CPU and GPU share memory,
593        // so we load directly to the device.
594        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
595            loading_isq = false;
596            if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
597                Device::Cpu
598            } else {
599                device.clone()
600            }
601        } else {
602            Device::Cpu
603        };
604
605        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
606
607        let attention_mechanism = if paged_attn_config.is_some() {
608            AttentionImplementation::PagedAttention
609        } else {
610            AttentionImplementation::Eager
611        };
612
613        let multi_progress = Arc::new(new_multi_progress());
614
615        // Load matformer slicing config if provided
616        let matformer_slicing_config = if let Some(matformer_path) =
617            &self.config.matformer_config_path
618        {
619            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
620            info!("Loading Matformer config from {:?}", matformer_path);
621            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
622
623            if let Some(slice_name) = &self.config.matformer_slice_name {
624                info!("Using Matformer slice: {}", slice_name);
625                Some(MatformerSliceConfig::new(slice_name.clone(), config))
626            } else {
627                // If no slice name is provided but config exists, we'll need to handle this
628                // For now, return None and let the model handle the default slice selection
629                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
630                None
631            }
632        } else {
633            None
634        };
635
636        info!(
637            "{}",
638            WeightLoadingMode::from(WeightLoadingState {
639                from_uqff: self.config.from_uqff.is_some(),
640                loading_isq,
641                immediate_isq: use_immediate,
642                write_uqff: self.config.write_uqff.is_some(),
643            })
644            .message("model")
645        );
646
647        let mut model = if use_nccl || cfg!(feature = "ring") {
648            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
649                dtype,
650                &device,
651                &available_devices,
652                silent,
653                &config,
654                loading_isq,
655                self.config.from_uqff.is_some(),
656                self.config.organization,
657                &*self.inner,
658                paths.as_ref(),
659            )?;
660
661            // Special case for where things can be more optimially loaded.
662            match self.kind {
663                ModelKind::Normal => normal_model_loader_sharded!(
664                    sharded_vb,
665                    config,
666                    self.inner,
667                    mapper,
668                    loading_isq,
669                    device.clone(),
670                    attention_mechanism,
671                    multi_progress.clone(),
672                    matformer_slicing_config.clone(),
673                ),
674                ModelKind::Adapter {
675                    adapter: AdapterKind::XLora,
676                } => xlora_model_loader!(
677                    paths,
678                    Some(dtype),
679                    &load_device,
680                    layer_devices.clone(),
681                    config,
682                    self.inner,
683                    silent,
684                    mapper,
685                    loading_isq,
686                    device.clone(),
687                    multi_progress.clone(),
688                    matformer_slicing_config.clone(),
689                ),
690                ModelKind::Adapter {
691                    adapter: AdapterKind::Lora,
692                } => lora_model_loader!(
693                    paths,
694                    Some(dtype),
695                    &load_device,
696                    layer_devices.clone(),
697                    config,
698                    self.inner,
699                    silent,
700                    mapper,
701                    loading_isq,
702                    self.config.from_uqff.is_some(),
703                    device.clone(),
704                    attention_mechanism,
705                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
706                    multi_progress.clone(),
707                    matformer_slicing_config.clone(),
708                ),
709                _ => unreachable!(),
710            }
711        } else {
712            match self.kind {
713                ModelKind::Normal => normal_model_loader!(
714                    paths,
715                    Some(dtype),
716                    &load_device,
717                    layer_devices.clone(),
718                    config,
719                    self.inner,
720                    silent,
721                    mapper,
722                    loading_isq,
723                    self.config.from_uqff.is_some(),
724                    device.clone(),
725                    attention_mechanism,
726                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
727                    multi_progress.clone(),
728                    matformer_slicing_config.clone(),
729                ),
730                ModelKind::Adapter {
731                    adapter: AdapterKind::XLora,
732                } => xlora_model_loader!(
733                    paths,
734                    Some(dtype),
735                    &load_device,
736                    layer_devices.clone(),
737                    config,
738                    self.inner,
739                    silent,
740                    mapper,
741                    loading_isq,
742                    device.clone(),
743                    multi_progress.clone(),
744                    matformer_slicing_config.clone(),
745                ),
746                ModelKind::Adapter {
747                    adapter: AdapterKind::Lora,
748                } => lora_model_loader!(
749                    paths,
750                    Some(dtype),
751                    &load_device,
752                    layer_devices.clone(),
753                    config,
754                    self.inner,
755                    silent,
756                    mapper,
757                    loading_isq,
758                    self.config.from_uqff.is_some(),
759                    device.clone(),
760                    attention_mechanism,
761                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
762                    multi_progress.clone(),
763                    matformer_slicing_config.clone(),
764                ),
765                _ => unreachable!(),
766            }
767        };
768
769        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
770        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
771            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
772                Ok(conf) => Some(conf),
773                Err(e) => {
774                    warn!("Failed to parse generation_config.json: {}", e);
775                    None
776                }
777            }
778        });
779
780        let chat_template_explicit = paths
781            .get_chat_template_explicit()
782            .as_ref()
783            .map(|x| x.to_string_lossy().to_string());
784        let chat_template = get_chat_template(
785            paths,
786            self.jinja_explicit.as_ref(),
787            chat_template_explicit.as_ref(),
788            self.chat_template.as_ref(),
789            None,
790        );
791
792        if let Some(calibration_file) = &self.config.calibration_file {
793            let calibration_data = std::fs::read_to_string(calibration_file)?;
794            // Tokenize, don't add bos yet
795            let tokens = tokenizer
796                .encode_fast(calibration_data, false)
797                .map_err(anyhow::Error::msg)?
798                .get_ids()
799                .to_vec();
800            info!(
801                "Collecting imatrix from calibration file `{}` of {} tokens.",
802                calibration_file.display(),
803                tokens.len()
804            );
805            let bos_tok_id = chat_template
806                .bos_tok()
807                .as_deref()
808                .and_then(|tok| tokenizer.token_to_id(tok));
809
810            match self.config.organization {
811                IsqOrganization::Default => model.begin_track_stats()?,
812                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
813            }
814
815            const CHUNK_SIZE: usize = 1024;
816            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
817            let start = Instant::now();
818            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
819                let mut chunk = chunk.to_vec();
820                if let Some(bos_tok_id) = bos_tok_id {
821                    chunk.insert(0, bos_tok_id);
822                }
823                let chunk_len = chunk.len();
824
825                let start = Instant::now();
826                let inputs = make_prompt_chunk(
827                    0,
828                    vec![&chunk],
829                    &[0],
830                    &load_device,
831                    None,
832                    false,
833                    None,
834                    Some(pipeline_mapper.as_ref()),
835                    None,
836                    model.config().sliding_window,
837                )?;
838
839                model.forward(
840                    &inputs.input.to_device(model.device())?,
841                    &inputs.positions,
842                    inputs.context_lens.clone(),
843                    inputs.position_ids.clone(),
844                    None,
845                    &inputs.flash_meta.clone(),
846                )?;
847
848                match model.cache_mut() {
849                    EitherCache::Full(full) => {
850                        for layer in &mut *full.lock() {
851                            *layer = None
852                        }
853                    }
854                    EitherCache::Normal(normal) => {
855                        for layer in &mut *normal.lock().unwrap().0 {
856                            layer.reset();
857                        }
858                    }
859                    EitherCache::Hybrid(hybrid) => {
860                        hybrid.lock().unwrap().reset();
861                    }
862                }
863
864                let end = Instant::now();
865                info!(
866                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
867                    i + 1,
868                    end.duration_since(start).as_secs_f32()
869                );
870            }
871            load_device.synchronize()?;
872            let end = Instant::now();
873            info!(
874                "Finished collecting imatrix in {:.2}s",
875                end.duration_since(start).as_secs_f32()
876            );
877        }
878
879        // Only if loading from UQFF
880        let should_serialize = self.config.write_uqff.is_some();
881        let should_quantize_pass = loading_isq;
882
883        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
884            let imatrix_source = if should_quantize_pass {
885                match (
886                    self.config.imatrix.as_ref(),
887                    self.config.calibration_file.is_some(),
888                ) {
889                    (None, false) => None,
890                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
891                    (None, true) => Some(ImatrixDataSource::Collected),
892                    (Some(_), true) => unreachable!(),
893                }
894            } else {
895                None
896            };
897
898            if should_quantize_pass {
899                debug!("Applying ISQ to all ranks.");
900            } else {
901                debug!("Serializing existing ISQ tensors without additional quantization.");
902            }
903
904            let multi_progress = Arc::new(new_multi_progress());
905
906            model.quantize(
907                in_situ_quant,
908                model.device().clone(),
909                self.config.topology.as_ref(),
910                silent,
911                imatrix_source,
912                self.config.organization,
913                should_quantize_pass,
914                self.config.write_uqff.as_ref(),
915                UqffFullSer {
916                    tokenizer: &tokenizer,
917                    template_filename: paths.get_template_filename(),
918                    generation_config: paths.get_gen_conf_filename(),
919                    config: config.clone(),
920                    processor_filename: &None,
921                    preprocessor_filename: &None,
922                    modules: None,
923                    module_paths: None,
924                },
925                multi_progress.clone(),
926            )?;
927        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
928            model.load_from_artifacts(
929                device.clone(),
930                self.config.topology.as_ref(),
931                silent,
932                from_uqff,
933            )?;
934        }
935
936        let paged_attn_config = if matches!(
937            self.kind,
938            ModelKind::Adapter {
939                adapter: AdapterKind::XLora
940            }
941        ) {
942            warn!(
943                "Adapter parallel_models do not currently support PagedAttention, running without"
944            );
945            None
946        } else {
947            paged_attn_config
948        };
949
950        let model_metadata = model.model_config();
951        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
952            let cache_config = calculate_cache_config(
953                paged_attn_config.mem_gpu,
954                paged_attn_config.block_size,
955                dtype,
956                paged_attn_config.cache_type,
957                model_metadata.as_ref(),
958                &device,
959                &pipeline_mapper
960                    .get_unique_devices()
961                    .into_iter()
962                    .map(Some)
963                    .collect::<Vec<_>>(),
964                silent,
965                None,
966                max_kv_tokens,
967            )?;
968
969            let mut layer_devices = Vec::new();
970            for layer in 0..self.inner.num_layers(&config)? {
971                let device = model.get_layers().1.device_for(layer, false).cloned();
972                layer_devices.push(device);
973            }
974            let cache_engine = CacheEngine::new(
975                model_metadata.as_ref(),
976                &cache_config,
977                dtype,
978                model.device(),
979                layer_devices.clone(),
980            )?;
981
982            (Some(cache_config), Some(cache_engine))
983        } else {
984            (None, None)
985        };
986
987        let max_seq_len = model.max_seq_len();
988        let llg_factory = build_llg_factory(tokenizer.clone())?;
989        let num_hidden_layers = match model.cache() {
990            EitherCache::Full(full) => full.lock().len(),
991            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
992            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
993        };
994        let generation_defaults = gen_conf
995            .as_ref()
996            .and_then(GenerationConfig::generation_defaults);
997        let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
998        let sliding_window = model.config().sliding_window;
999        Ok(Arc::new(Mutex::new(NormalPipeline {
1000            model,
1001            tokenizer: tokenizer.into(),
1002            no_kv_cache: self.no_kv_cache,
1003            chat_template: Arc::new(chat_template),
1004            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
1005                NonGranularState {
1006                    non_granular_index: Arc::new(Mutex::new(0)),
1007                    tgt_non_granular_index,
1008                }
1009            }),
1010            model_id: self.model_id.clone(),
1011            metadata: Arc::new(GeneralMetadata {
1012                max_seq_len,
1013                llg_factory: Some(llg_factory),
1014                no_kv_cache: self.no_kv_cache,
1015                no_prefix_cache: is_xlora,
1016                num_hidden_layers,
1017                eos_tok: eos,
1018                kind: self.kind.clone(),
1019                is_xlora,
1020                activation_dtype: dtype,
1021                sliding_window,
1022                cache_config,
1023                cache_engine,
1024                model_metadata: Some(model_metadata),
1025                modalities: Modalities {
1026                    input: vec![SupportedModality::Text],
1027                    output: vec![SupportedModality::Text],
1028                },
1029            }),
1030            topology: self.config.topology.clone(),
1031            silent,
1032            organization: self.config.organization,
1033            template_filename: paths.get_template_filename().clone(),
1034            generation_config: paths.get_gen_conf_filename().cloned(),
1035            generation_defaults,
1036            config,
1037            imatrix: self.config.imatrix.clone(),
1038            mapper: pipeline_mapper,
1039        })))
1040    }
1041
1042    fn get_id(&self) -> String {
1043        self.model_id.clone()
1044    }
1045
1046    fn get_kind(&self) -> ModelKind {
1047        self.kind.clone()
1048    }
1049}
1050
1051impl PreProcessingMixin for NormalPipeline {
1052    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
1053        Some(self.chat_template.clone())
1054    }
1055    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
1056        None
1057    }
1058}
1059
1060impl IsqPipelineMixin for NormalPipeline {
1061    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
1062        let device = self.device().clone();
1063        let multi_progress = Arc::new(new_multi_progress());
1064        self.model.quantize(
1065            Some(dtype),
1066            device.clone(),
1067            self.topology.as_ref(),
1068            self.silent,
1069            self.imatrix.as_ref().map(ImatrixDataSource::File),
1070            self.organization,
1071            true,
1072            None,
1073            UqffFullSer {
1074                tokenizer: &self.tokenizer,
1075                template_filename: &self.template_filename,
1076                generation_config: self.generation_config.as_ref(),
1077                config: self.config.clone(),
1078                processor_filename: &None,
1079                preprocessor_filename: &None,
1080                modules: None,
1081                module_paths: None,
1082            },
1083            multi_progress.clone(),
1084        )?;
1085        Ok(())
1086    }
1087}
1088
1089impl CacheManagerMixin for NormalPipeline {
1090    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
1091        match self.model.cache() {
1092            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
1093            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
1094            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
1095        }
1096    }
1097    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1098        match self.model.cache() {
1099            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
1100            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
1101            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
1102        }
1103    }
1104    fn set_none_cache(
1105        &self,
1106        seqs: &mut [&mut Sequence],
1107        reset_non_granular: bool,
1108        modify_draft_cache: bool,
1109        load_preallocated_cache: bool,
1110    ) {
1111        match self.model.cache() {
1112            EitherCache::Full(_) => {
1113                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
1114            }
1115            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
1116                self,
1117                seqs,
1118                modify_draft_cache,
1119                load_preallocated_cache,
1120            ),
1121            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
1122                self,
1123                seqs,
1124                modify_draft_cache,
1125                load_preallocated_cache,
1126            ),
1127        }
1128        if reset_non_granular {
1129            self.reset_non_granular_state()
1130        }
1131    }
1132    fn cache(&self) -> &EitherCache {
1133        self.model.cache()
1134    }
1135}
1136
1137impl MetadataMixin for NormalPipeline {
1138    fn device(&self) -> Device {
1139        self.model.device().clone()
1140    }
1141    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1142        Some(self.tokenizer.clone())
1143    }
1144    fn name(&self) -> String {
1145        self.model_id.clone()
1146    }
1147    fn reset_non_granular_state(&self) {
1148        if let Some(s) = self.non_granular_state.as_ref() {
1149            *self.cache().full().get_scalings_cache() = None;
1150            *get_mut_arcmutex!(s.non_granular_index) = 0;
1151        }
1152    }
1153    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1154        self.metadata.clone()
1155    }
1156    fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
1157        self.generation_defaults.clone()
1158    }
1159    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1160        Some(&*self.mapper)
1161    }
1162}
1163
1164impl crate::speculative::driver::SpeculativePipelineExt for NormalPipeline {
1165    fn has_speculative_proposer(&self) -> bool {
1166        self.model.has_speculative_proposer()
1167    }
1168
1169    fn speculative_proposal_len(&self) -> Option<usize> {
1170        self.model.speculative_proposal_len()
1171    }
1172
1173    fn speculative_target_hiddens(
1174        &self,
1175        rows: &[(usize, usize)],
1176    ) -> hanzo_ml::Result<Option<Tensor>> {
1177        self.model.speculative_target_hiddens(rows)
1178    }
1179
1180    fn speculative_propose(
1181        &mut self,
1182        ctx: crate::speculative::SpeculativeProposeBatchCtx<'_>,
1183    ) -> hanzo_ml::Result<Option<crate::speculative::SpeculativeProposalBatch>> {
1184        self.model.speculative_propose(ctx)
1185    }
1186
1187    fn build_speculative_verify_inputs(
1188        &self,
1189        input_meta: InputMetadata,
1190    ) -> hanzo_ml::Result<Box<dyn Any>> {
1191        Ok(Box::new(ModelInputs {
1192            input_ids: input_meta.input,
1193            input_ids_full: None,
1194            seqlen_offsets: input_meta.positions,
1195            seqlen_offsets_full: None,
1196            context_lens: input_meta.context_lens,
1197            position_ids: input_meta.position_ids,
1198            paged_attn_meta: input_meta.paged_attn_meta,
1199            flash_meta: input_meta.flash_meta,
1200            flash_meta_full: None,
1201        }))
1202    }
1203}
1204
1205#[async_trait::async_trait]
1206impl Pipeline for NormalPipeline {
1207    fn forward_inputs(
1208        &mut self,
1209        inputs: Box<dyn Any>,
1210        return_raw_logits: bool,
1211    ) -> Result<ForwardInputsResult, hanzo_ml::Error> {
1212        let ModelInputs {
1213            input_ids,
1214            input_ids_full,
1215            seqlen_offsets,
1216            seqlen_offsets_full,
1217            context_lens,
1218            position_ids,
1219            paged_attn_meta,
1220            flash_meta,
1221            flash_meta_full,
1222        } = *inputs.downcast().expect("Downcast failed.");
1223        let metadata = self.get_metadata();
1224        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1225            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1226            (Some(_), None) => {
1227                // This can happen if Rust-side user code is wrong
1228                hanzo_ml::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1229            }
1230            (None, Some(_)) => {
1231                // This should never happen but we handle it anyway
1232                hanzo_ml::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1233            }
1234            (None, None) => None,
1235        };
1236        let logits = match self.model.is_xlora() {
1237            false => {
1238                let paged_attn_meta = paged_attn_meta
1239                    .as_ref()
1240                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1241
1242                self.model.forward(
1243                    &input_ids,
1244                    &seqlen_offsets,
1245                    context_lens,
1246                    position_ids,
1247                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1248                    &flash_meta,
1249                )?
1250            }
1251            true => self.model.xlora_forward(
1252                &input_ids,
1253                input_ids_full.as_ref().unwrap_or(&input_ids),
1254                &seqlen_offsets,
1255                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1256                self.no_kv_cache,
1257                &self.non_granular_state,
1258                context_lens,
1259                position_ids,
1260                &flash_meta,
1261                flash_meta_full.as_ref().unwrap_or(&flash_meta),
1262            )?,
1263        };
1264        if return_raw_logits {
1265            Ok(ForwardInputsResult::RawLogits { logits })
1266        } else {
1267            Ok(ForwardInputsResult::CausalGeneration { logits })
1268        }
1269    }
1270    fn attach_speculative(
1271        &mut self,
1272        config: crate::speculative::SpeculativeConfig,
1273    ) -> hanzo_ml::Result<()> {
1274        if matches!(config, crate::speculative::SpeculativeConfig::Mtp(_))
1275            && self.get_metadata().cache_engine.is_none()
1276        {
1277            hanzo_ml::bail!(
1278                "MTP speculative decoding currently requires PagedAttention for this pipeline."
1279            );
1280        }
1281        if let Some(info) = self.model.attach_speculative(config)? {
1282            self.model.log_speculative_attach(&info);
1283        }
1284        Ok(())
1285    }
1286
1287    #[allow(clippy::too_many_arguments)]
1288    async fn try_sample_speculative_causal_gen(
1289        &mut self,
1290        seqs: &mut [&mut Sequence],
1291        logits: &[Tensor],
1292        prefix_cacher: &mut PrefixCacheManagerV2,
1293        disable_eos_stop: bool,
1294        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1295        metadata: Option<crate::pipeline::text_models_inputs_processor::PagedAttentionMeta>,
1296    ) -> hanzo_ml::Result<bool> {
1297        if !self.model.has_speculative_proposer() {
1298            crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1299            return Ok(false);
1300        }
1301
1302        let general_metadata = self.get_metadata();
1303        if let Some(cache_engine) = general_metadata.cache_engine.as_ref() {
1304            let Some(metadata) = metadata else {
1305                crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1306                return Ok(false);
1307            };
1308            let cache = crate::speculative::cache::PagedSpeculativeCacheAccess::new(
1309                &metadata,
1310                cache_engine,
1311            );
1312            return crate::speculative::driver::try_sample_speculative_causal_gen(
1313                self,
1314                seqs,
1315                logits,
1316                prefix_cacher,
1317                disable_eos_stop,
1318                rng,
1319                &cache,
1320            )
1321            .await;
1322        }
1323
1324        crate::speculative::driver::clear_staged_speculative_tokens(seqs);
1325        Ok(false)
1326    }
1327
1328    async fn sample_causal_gen(
1329        &self,
1330        seqs: &mut [&mut Sequence],
1331        logits: Vec<Tensor>,
1332        prefix_cacher: &mut PrefixCacheManagerV2,
1333        disable_eos_stop: bool,
1334        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1335    ) -> Result<(), hanzo_ml::Error> {
1336        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1337    }
1338    fn category(&self) -> ModelCategory {
1339        ModelCategory::Text
1340    }
1341}
1342
1343impl AnyMoePipelineMixin for NormalPipeline {
1344    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> hanzo_ml::Result<()> {
1345        self.model.finish_training(gate_model_id)
1346    }
1347    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1348        self.model.get_vars()
1349    }
1350    fn amoe_base_model_trainable_params(&self) -> usize {
1351        self.model.trainable_params()
1352    }
1353    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1354        self.model.take_cached_gating_outputs()
1355    }
1356    fn amoe_create_layers(
1357        &mut self,
1358        model_ids: Vec<String>,
1359        token: &TokenSource,
1360        revision: Option<String>,
1361        match_regex: &str,
1362        config: crate::amoe::AnyMoeConfig,
1363        dtype: hanzo_ml::DType,
1364        dev: &Device,
1365        (prefix, mlp): (String, String),
1366        layers: Vec<usize>,
1367        expert_type: AnyMoeExpertType,
1368        silent: bool,
1369        gate_model_id: Option<String>,
1370    ) -> hanzo_ml::Result<()> {
1371        let mut vbs = Vec::new();
1372        // Precompile regex here
1373        let regex = Regex::new(match_regex).map_err(hanzo_ml::Error::msg)?;
1374        for model_id in model_ids {
1375            let model_id_str = &model_id;
1376            let model_id = Path::new(&model_id);
1377
1378            let api = {
1379                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1380                let mut api = ApiBuilder::from_cache(cache)
1381                    .with_progress(!silent)
1382                    .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1383                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1384                    api = api.with_cache_dir(cache_dir);
1385                }
1386                api.build().map_err(hanzo_ml::Error::msg)?
1387            };
1388            let revision = revision.clone().unwrap_or("main".to_string());
1389            let api = api.repo(Repo::with_revision(
1390                model_id_str.clone(),
1391                RepoType::Model,
1392                revision.clone(),
1393            ));
1394
1395            let mut filenames = vec![];
1396            for rfilename in api_dir_list!(api, model_id, true, &revision)
1397                .filter(|x| x.ends_with(".safetensors"))
1398            {
1399                filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1400            }
1401
1402            let regex = regex.clone();
1403            let match_regex_clone = match_regex.to_string();
1404            let layers_clone = layers.clone();
1405            let vb = from_mmaped_safetensors(
1406                filenames,
1407                vec![],
1408                Some(dtype),
1409                dev,
1410                vec![None],
1411                silent,
1412                None,
1413                move |key| {
1414                    if regex.is_match(&key) {
1415                        // Idx of the last char of the layer id, +1
1416                        // Assumes N.MLP
1417                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1418                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1419                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1420                            .parse::<usize>()
1421                            .unwrap();
1422                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1423                    } else {
1424                        false
1425                    }
1426                },
1427                Arc::new(|_| DeviceForLoadTensor::Base),
1428            )?;
1429            vbs.push(vb);
1430        }
1431
1432        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1433            let model_id_str = &gate_model_id;
1434            let model_id = Path::new(&gate_model_id);
1435
1436            let api = {
1437                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1438                let mut api = ApiBuilder::from_cache(cache)
1439                    .with_progress(!silent)
1440                    .with_token(get_token(token).map_err(hanzo_ml::Error::msg)?);
1441                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
1442                    api = api.with_cache_dir(cache_dir);
1443                }
1444                api.build().map_err(hanzo_ml::Error::msg)?
1445            };
1446            let revision = revision.clone().unwrap_or("main".to_string());
1447            let api = api.repo(Repo::with_revision(
1448                model_id_str.clone(),
1449                RepoType::Model,
1450                revision.clone(),
1451            ));
1452
1453            let mut gate_filenames = vec![];
1454            for rfilename in api_dir_list!(api, model_id, true, &revision)
1455                .filter(|x| x.ends_with(".safetensors"))
1456            {
1457                gate_filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
1458            }
1459            assert_eq!(
1460                gate_filenames.len(),
1461                1,
1462                "Gate model ID must contain only one .safetensors file"
1463            );
1464
1465            let vb = from_mmaped_safetensors(
1466                gate_filenames.clone(),
1467                vec![],
1468                Some(dtype),
1469                dev,
1470                vec![None],
1471                silent,
1472                None,
1473                |_| true,
1474                Arc::new(|_| DeviceForLoadTensor::Base),
1475            )?;
1476            info!(
1477                "Loaded gating layers from `{}`",
1478                gate_filenames[0].display()
1479            );
1480            Some(vb)
1481        } else {
1482            None
1483        };
1484
1485        self.model.create_anymoe_layers(
1486            vbs.clone(),
1487            config.clone(),
1488            (prefix.clone(), mlp.clone()),
1489            layers.clone(),
1490            expert_type.clone(),
1491            gate_vb.clone(),
1492        )?;
1493
1494        Ok(())
1495    }
1496    fn amoe_supported(&self) -> bool {
1497        self.model.amoe_supported()
1498    }
1499}