// hanzo_engine/pipeline/gguf.rs

1use super::llg::build_llg_factory;
2use super::{
3    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
4    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
5    TokenSource,
6};
7use super::{
8    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
9    MetadataMixin, ModelCategory, PreProcessingMixin,
10};
11use crate::attention::ATTENTION_CHUNK_SIZE;
12use crate::device_map::{self, DeviceMapper};
13use crate::distributed::WorkerTransferData;
14use crate::gguf::{
15    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
16};
17use crate::gguf::{Content, GGUFArchitecture};
18use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
19use crate::lora::Ordering;
20use crate::paged_attention::{
21    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
22};
23use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
24use crate::pipeline::loaders::DeviceMappedModelLoader;
25use crate::pipeline::sampling::sample_and_add_toks;
26use crate::pipeline::ChatTemplate;
27use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
28use crate::prefix_cacher::PrefixCacheManagerV2;
29use crate::sequence::Sequence;
30use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
31use crate::utils::model_config as ModelConfig;
32use crate::utils::progress::ProgressScopeGuard;
33use crate::utils::tokenizer::get_tokenizer;
34use crate::xlora_models::NonGranularState;
35use crate::{
36    distributed, get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths,
37    PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
38};
39use crate::{
40    models::quantized_llama::ModelWeights as QLlama,
41    models::quantized_phi2::ModelWeights as QPhi,
42    models::quantized_phi3::ModelWeights as QPhi3,
43    models::quantized_qwen::ModelWeights as QQwen,
44    models::quantized_qwen3::ModelWeights as QQwen3,
45    models::quantized_qwen3_5_moe::ModelWeights as QQwen35,
46    models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
47    models::quantized_starcoder2::ModelWeights as QStarcoder2,
48    utils::tokens::get_token,
49    xlora_models::{XLoraQLlama, XLoraQPhi3},
50};
51use anyhow::{bail, Result};
52use either::Either;
53use hanzo_ml::{Device, Tensor};
54use hanzo_quant::IsqType;
55use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
56use rand_isaac::Isaac64Rng;
57use std::any::Any;
58use std::path::PathBuf;
59use std::str::FromStr;
60use std::sync::Arc;
61use std::{env, fs};
62use tokenizers::Tokenizer;
63use tokio::sync::Mutex;
64use tracing::{debug, info, warn};
65
66enum Model {
67    Llama(QLlama),
68    Phi2(QPhi),
69    XLoraLlama(XLoraQLlama),
70    XLoraPhi3(XLoraQPhi3),
71    Phi3(QPhi3),
72    Starcoder2(QStarcoder2),
73    Qwen(QQwen),
74    Qwen3(QQwen3),
75    Qwen3MoE(QQwen3MoE),
76    Qwen35(QQwen35),
77}
78
79pub struct GGUFPipeline {
80    model: Model,
81    tokenizer: Arc<Tokenizer>,
82    no_kv_cache: bool,
83    chat_template: Arc<ChatTemplate>,
84    model_id: String,
85    non_granular_state: Option<NonGranularState>,
86    metadata: Arc<GeneralMetadata>,
87    generation_defaults: Option<crate::ModelGenerationDefaults>,
88    mapper: Box<dyn DeviceMapper + Send + Sync>,
89}
90
91/// Loader for a GGUF model.
92pub struct GGUFLoader {
93    model_id: Option<String>,
94    quantized_model_id: String,
95    quantized_filenames: Vec<String>,
96    xlora_model_id: Option<String>,
97    xlora_order: Option<Ordering>,
98    no_kv_cache: bool,
99    chat_template: Option<String>,
100    kind: ModelKind,
101    tgt_non_granular_index: Option<usize>,
102    config: GGUFSpecificConfig,
103    jinja_explicit: Option<String>,
104    lora_adapter_ids: Option<Vec<String>>,
105}
106
107#[derive(Clone, Default)]
108/// Config for a GGUF loader.
109pub struct GGUFSpecificConfig {
110    pub topology: Option<Topology>,
111}
112
113#[derive(Default)]
114/// A builder for a GGUF loader.
115pub struct GGUFLoaderBuilder {
116    model_id: Option<String>,
117    quantized_model_id: String,
118    quantized_filenames: Vec<String>,
119    xlora_model_id: Option<String>,
120    kind: ModelKind,
121    xlora_order: Option<Ordering>,
122    no_kv_cache: bool,
123    chat_template: Option<String>,
124    tgt_non_granular_index: Option<usize>,
125    config: GGUFSpecificConfig,
126    jinja_explicit: Option<String>,
127}
128
129impl GGUFLoaderBuilder {
130    /// Create a loader builder for a GGUF model. `tok_model_id` is the model ID where you can find a
131    /// `tokenizer_config.json` file. If the `chat_template` is specified, then it will be treated as a
132    /// path and used over remote files, removing all remote accesses.
133    pub fn new(
134        chat_template: Option<String>,
135        tok_model_id: Option<String>,
136        quantized_model_id: String,
137        quantized_filenames: Vec<String>,
138        config: GGUFSpecificConfig,
139        no_kv_cache: bool,
140        jinja_explicit: Option<String>,
141    ) -> Self {
142        let kind = ModelKind::GgufQuantized {
143            quant: QuantizationKind::Gguf,
144        };
145
146        Self {
147            chat_template,
148            model_id: tok_model_id,
149            kind,
150            quantized_filenames,
151            quantized_model_id,
152            config,
153            jinja_explicit,
154            no_kv_cache,
155            ..Default::default()
156        }
157    }
158
159    fn with_adapter(
160        mut self,
161        xlora_model_id: String,
162        xlora_order: Ordering,
163        no_kv_cache: bool,
164        tgt_non_granular_index: Option<usize>,
165    ) -> Self {
166        self.xlora_model_id = Some(xlora_model_id);
167        self.xlora_order = Some(xlora_order);
168        self.no_kv_cache = no_kv_cache;
169        self.tgt_non_granular_index = tgt_non_granular_index;
170        self.model_id = if let Some(id) = self.model_id {
171            Some(id)
172        } else {
173            info!(
174                "Using adapter base model ID: `{}`",
175                self.xlora_order.as_ref().unwrap().base_model_id
176            );
177            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
178        };
179        self
180    }
181
182    pub fn with_xlora(
183        mut self,
184        xlora_model_id: String,
185        xlora_order: Ordering,
186        no_kv_cache: bool,
187        tgt_non_granular_index: Option<usize>,
188    ) -> Self {
189        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();
190
191        self.with_adapter(
192            xlora_model_id,
193            xlora_order,
194            no_kv_cache,
195            tgt_non_granular_index,
196        )
197    }
198
199    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
200        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();
201
202        self.with_adapter(lora_model_id, lora_order, false, None)
203    }
204
205    pub fn build(self) -> Box<dyn Loader> {
206        Box::new(GGUFLoader {
207            model_id: self.model_id,
208            xlora_model_id: self.xlora_model_id,
209            kind: self.kind,
210            xlora_order: self.xlora_order,
211            no_kv_cache: self.no_kv_cache,
212            chat_template: self.chat_template,
213            tgt_non_granular_index: self.tgt_non_granular_index,
214            quantized_filenames: self.quantized_filenames,
215            quantized_model_id: self.quantized_model_id,
216            config: self.config,
217            jinja_explicit: self.jinja_explicit,
218            lora_adapter_ids: None,
219        })
220    }
221}
222
223impl GGUFLoader {
224    #[allow(clippy::too_many_arguments)]
225    pub fn new(
226        model_id: Option<String>,
227        quantized_model_id: String,
228        quantized_filenames: Vec<String>,
229        xlora_model_id: Option<String>,
230        kind: ModelKind,
231        xlora_order: Option<Ordering>,
232        no_kv_cache: bool,
233        chat_template: Option<String>,
234        tgt_non_granular_index: Option<usize>,
235        config: GGUFSpecificConfig,
236        jinja_explicit: Option<String>,
237    ) -> Self {
238        let model_id = if let Some(id) = model_id {
239            Some(id)
240        } else if let Some(xlora_order) = xlora_order.clone() {
241            info!(
242                "Using adapter base model ID: `{}`",
243                xlora_order.base_model_id
244            );
245            Some(xlora_order.base_model_id.clone())
246        } else {
247            None
248        };
249        Self {
250            model_id,
251            quantized_model_id,
252            quantized_filenames,
253            xlora_model_id,
254            xlora_order,
255            no_kv_cache,
256            chat_template,
257            kind,
258            tgt_non_granular_index,
259            config,
260            jinja_explicit,
261            lora_adapter_ids: None,
262        }
263    }
264}
265
266impl Loader for GGUFLoader {
267    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
268    fn load_model_from_hf(
269        &self,
270        revision: Option<String>,
271        token_source: TokenSource,
272        dtype: &dyn TryIntoDType,
273        device: &Device,
274        silent: bool,
275        mapper: DeviceMapSetting,
276        in_situ_quant: Option<IsqType>,
277        paged_attn_config: Option<PagedAttentionConfig>,
278    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
279        let _progress_guard = ProgressScopeGuard::new(silent);
280        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
281            LocalModelPaths,
282            &token_source,
283            revision,
284            self,
285            self.quantized_model_id.clone(),
286            self.quantized_filenames.clone(),
287            silent
288        );
289
290        self.load_model_from_path(
291            &paths?,
292            dtype,
293            device,
294            silent,
295            mapper,
296            in_situ_quant,
297            paged_attn_config,
298        )
299    }
300
301    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
302    fn load_model_from_path(
303        &self,
304        paths: &Box<dyn ModelPaths>,
305        dtype: &dyn TryIntoDType,
306        device: &Device,
307        silent: bool,
308        mut mapper: DeviceMapSetting,
309        in_situ_quant: Option<IsqType>,
310        mut paged_attn_config: Option<PagedAttentionConfig>,
311    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
312        let _progress_guard = ProgressScopeGuard::new(silent);
313        if in_situ_quant.is_some() {
314            anyhow::bail!(
315                "You are trying to in-situ quantize a GGUF model. This will not do anything."
316            );
317        }
318
319        debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
320
321        let mut readers = Vec::new();
322        for filename in paths.get_weight_filenames() {
323            readers.push(std::fs::File::open(filename)?);
324        }
325        let mut readers = readers.iter_mut().collect::<Vec<_>>();
326        let model = Content::from_readers(&mut readers)?;
327
328        if !silent {
329            model.print_metadata()?;
330        }
331
332        let arch = model.arch();
333
334        // If auto, convert to Map
335        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
336
337        let mut max_kv_tokens: Option<usize> = None;
338
339        if let DeviceMapSetting::Auto(params) = mapper.clone() {
340            let devices = device_map::get_all_similar_devices(device)?;
341            // Initial dtype
342            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;
343
344            let model = GgufDeviceMapLoaderInner {
345                model: &model,
346                arch,
347            };
348
349            let layer_sizes_in_bytes =
350                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1, None)?;
351            let non_mapped_size_in_bytes =
352                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1, None)?;
353            let total_model_size_in_bytes =
354                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;
355
356            let new = model.get_device_layers(
357                "this is a dummy config!",
358                num_layers,
359                layer_sizes_in_bytes,
360                non_mapped_size_in_bytes,
361                total_model_size_in_bytes,
362                &devices,
363                dtype,
364                &params,
365                paged_attn_config.as_ref(),
366            )?;
367            max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
368            mapper = DeviceMapSetting::Map(new);
369        }
370
371        #[cfg(feature = "cuda")]
372        if let Device::Cuda(dev) = &device {
373            unsafe { dev.disable_event_tracking() };
374        }
375
376        let use_nccl = hanzo_quant::distributed::use_nccl();
377        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
378            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
379            let WorkerTransferData::Init { id: _, worker_rank } = payload;
380            vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
381        } else if use_nccl {
382            vec![hanzo_ml::Device::new_cuda(0)?]
383        } else {
384            device_map::get_all_similar_devices(device)?
385        };
386
387        let pipeline_mapper = mapper.into_mapper(
388            num_layers,
389            device,
390            self.config.topology.as_ref(),
391            &available_devices,
392        )?;
393        let mapper = mapper.into_mapper(
394            num_layers,
395            device,
396            self.config.topology.as_ref(),
397            &available_devices,
398        )?;
399        let mut layer_devices = Vec::new();
400        for layer in 0..num_layers {
401            let device = mapper.device_for(layer, false).cloned();
402            layer_devices.push(device);
403        }
404
405        // TODO: PagedAttention is not supported with CPU for now.
406        // This check is not really necessary because `get_device_layers` should prevent it.
407        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
408        if mapping_uses_cpu {
409            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
410            paged_attn_config = None;
411        }
412
413        let GgufTokenizerConversion {
414            tokenizer,
415            bos,
416            eos,
417            unk,
418        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
419            convert_gguf_to_hf_tokenizer(&model)?
420        } else {
421            GgufTokenizerConversion {
422                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
423                bos: None,
424                eos: None,
425                unk: None,
426            }
427        };
428
429        // Only load gguf chat template if there is nothing else
430        let gguf_chat_template =
431            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
432                get_gguf_chat_template(&model)?
433            } else {
434                None
435            };
436
437        let has_adapter = self.kind.is_adapted();
438        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
439
440        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
441            warn!("Adapter models do not currently support PagedAttention, running without");
442            None
443        } else {
444            paged_attn_config
445        };
446
447        let model_config_metadata: ContentConfig = (&model).into();
448        let internal_dtype = mapper.get_min_dtype(dtype)?;
449
450        let model_config = {
451            // Base config (quantization only):
452            let quant = ModelConfig::ParamsGGUF(
453                model,
454                (device, mapper).into(),
455                if paged_attn_config.is_some() {
456                    AttentionImplementation::PagedAttention
457                } else {
458                    AttentionImplementation::Eager
459                },
460                internal_dtype,
461            );
462
463            // With optional adapter config:
464            let mut adapter = None;
465            if has_adapter {
466                adapter.replace(ModelConfig::Adapter::try_new(
467                    paths, device, silent, is_xlora,
468                )?);
469            }
470
471            ModelConfig::ModelParams::new(quant, adapter)
472        };
473
474        // Config into model:
475        let model = match self.kind {
476            ModelKind::GgufQuantized { .. } => match arch {
477                GGUFArchitecture::Llama | GGUFArchitecture::Mistral3 => {
478                    Model::Llama(QLlama::try_from(model_config)?)
479                }
480                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
481                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
482                GGUFArchitecture::Starcoder2 => {
483                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
484                }
485                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
486                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
487                GGUFArchitecture::Qwen3MoE => Model::Qwen3MoE(QQwen3MoE::try_from(model_config)?),
488                GGUFArchitecture::Qwen35 | GGUFArchitecture::Qwen35MoE => {
489                    Model::Qwen35(QQwen35::try_from(model_config)?)
490                }
491                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
492            },
493            ModelKind::GgufAdapter { adapter, .. } => match arch {
494                GGUFArchitecture::Llama | GGUFArchitecture::Mistral3 => {
495                    Model::XLoraLlama(XLoraQLlama::try_from(model_config)?)
496                }
497                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
498                a => bail!(
499                    "Unsupported architecture `{a:?}` for GGUF {kind}",
500                    kind = adapter.pretty_name()
501                ),
502            },
503            _ => unreachable!(),
504        };
505
506        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
507            let model_config: &dyn ModelConfigLike = &model_config_metadata;
508            let cache_config = calculate_cache_config(
509                paged_attn_config.mem_gpu,
510                paged_attn_config.block_size,
511                internal_dtype,
512                paged_attn_config.cache_type,
513                model_config,
514                device,
515                &layer_devices,
516                silent,
517                None,
518                max_kv_tokens,
519            )?;
520            let cache_engine = CacheEngine::new(
521                model_config,
522                &cache_config,
523                internal_dtype,
524                device,
525                layer_devices,
526            )?;
527            (Some(cache_config), Some(cache_engine))
528        } else {
529            (None, None)
530        };
531
532        let gen_conf: Option<GenerationConfig> = paths
533            .get_gen_conf_filename()
534            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
535        let chat_template_explicit = paths
536            .get_chat_template_explicit()
537            .as_ref()
538            .map(|x| x.to_string_lossy().to_string());
539        let mut chat_template = get_chat_template(
540            paths,
541            self.jinja_explicit.as_ref(),
542            chat_template_explicit.as_ref(),
543            self.chat_template.as_ref(),
544            gguf_chat_template,
545        );
546
547        let max_seq_len = match model {
548            Model::Llama(ref l) => l.max_seq_len,
549            Model::Phi2(ref p) => p.max_seq_len,
550            Model::XLoraLlama(ref xl) => xl.max_seq_len,
551            Model::Phi3(ref p) => p.max_seq_len,
552            Model::XLoraPhi3(ref p) => p.max_seq_len,
553            Model::Starcoder2(ref p) => p.max_seq_len,
554            Model::Qwen(ref p) => p.max_seq_len,
555            Model::Qwen3(ref p) => p.max_seq_len,
556            Model::Qwen3MoE(ref p) => p.max_seq_len,
557            Model::Qwen35(ref p) => p.max_seq_len,
558        };
559        let llg_factory = build_llg_factory(tokenizer.clone())?;
560        let num_hidden_layers = match model {
561            Model::Llama(ref model) => model.cache.normal().0.len(),
562            Model::Phi2(ref model) => model.cache.normal().0.len(),
563            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
564            Model::Phi3(ref model) => model.cache.normal().0.len(),
565            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
566            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
567            Model::Qwen(ref model) => model.cache.normal().0.len(),
568            Model::Qwen3(ref model) => model.cache.normal().0.len(),
569            Model::Qwen3MoE(ref model) => model.cache.normal().0.len(),
570            Model::Qwen35(ref model) => model.cache.hybrid().num_layers(),
571        };
572
573        if chat_template.bos_token.is_none() {
574            if let Some(v) = bos {
575                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
576            }
577        }
578        if chat_template.eos_token.is_none() {
579            if let Some(v) = eos {
580                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
581            }
582        }
583        if chat_template.unk_token.is_none() {
584            if let Some(v) = unk {
585                chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(v)));
586            }
587        }
588
589        let generation_defaults = gen_conf
590            .as_ref()
591            .and_then(GenerationConfig::generation_defaults);
592        let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
593        Ok(Arc::new(Mutex::new(GGUFPipeline {
594            model,
595            tokenizer: tokenizer.into(),
596            no_kv_cache: self.no_kv_cache,
597            chat_template: Arc::new(chat_template),
598            model_id: self
599                .model_id
600                .clone()
601                .unwrap_or(self.quantized_model_id.clone()),
602            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
603                NonGranularState {
604                    non_granular_index: Arc::new(Mutex::new(0)),
605                    tgt_non_granular_index,
606                }
607            }),
608            metadata: Arc::new(GeneralMetadata {
609                max_seq_len,
610                llg_factory: Some(llg_factory),
611                no_kv_cache: self.no_kv_cache,
612                no_prefix_cache: false,
613                num_hidden_layers,
614                eos_tok: eos,
615                kind: self.kind.clone(),
616                is_xlora,
617                activation_dtype: internal_dtype,
618                sliding_window: None,
619                cache_config,
620                cache_engine,
621                model_metadata: Some(Arc::new(model_config_metadata)),
622                modalities: Modalities {
623                    input: vec![SupportedModality::Text],
624                    output: vec![SupportedModality::Text],
625                },
626            }),
627            generation_defaults,
628            mapper: pipeline_mapper,
629        })))
630    }
631
632    fn get_id(&self) -> String {
633        self.xlora_model_id
634            .as_deref()
635            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
636            .to_string()
637    }
638
639    fn get_kind(&self) -> ModelKind {
640        self.kind.clone()
641    }
642}
643
644impl PreProcessingMixin for GGUFPipeline {
645    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
646        Some(self.chat_template.clone())
647    }
648    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
649        None
650    }
651}
652
653impl IsqPipelineMixin for GGUFPipeline {
654    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
655        anyhow::bail!(
656            "You are trying to in-situ requantize a GGML model. This will not do anything."
657        )
658    }
659}
660
661impl CacheManagerMixin for GGUFPipeline {
662    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
663        match self.cache() {
664            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
665            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
666            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
667        }
668    }
669    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
670        match self.cache() {
671            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
672            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
673            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
674        }
675    }
676    fn set_none_cache(
677        &self,
678        seqs: &mut [&mut Sequence],
679        reset_non_granular: bool,
680        modify_draft_cache: bool,
681        load_preallocated_cache: bool,
682    ) {
683        match self.cache() {
684            EitherCache::Full(_) => {
685                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
686            }
687            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
688                self,
689                seqs,
690                modify_draft_cache,
691                load_preallocated_cache,
692            ),
693            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
694                self,
695                seqs,
696                modify_draft_cache,
697                load_preallocated_cache,
698            ),
699        }
700        if reset_non_granular {
701            self.reset_non_granular_state()
702        }
703    }
704    fn cache(&self) -> &EitherCache {
705        match self.model {
706            Model::Llama(ref model) => &model.cache,
707            Model::Phi2(ref model) => &model.cache,
708            Model::XLoraLlama(ref model) => &model.cache,
709            Model::Phi3(ref model) => &model.cache,
710            Model::XLoraPhi3(ref model) => &model.cache,
711            Model::Starcoder2(ref model) => &model.cache,
712            Model::Qwen(ref model) => &model.cache,
713            Model::Qwen3(ref model) => &model.cache,
714            Model::Qwen3MoE(ref model) => &model.cache,
715            Model::Qwen35(ref model) => &model.cache,
716        }
717    }
718}
719
720impl MetadataMixin for GGUFPipeline {
721    fn device(&self) -> Device {
722        match self.model {
723            Model::Llama(ref model) => model.device.clone(),
724            Model::Phi2(ref model) => model.device.clone(),
725            Model::XLoraLlama(ref model) => model.device.clone(),
726            Model::Phi3(ref model) => model.device.clone(),
727            Model::XLoraPhi3(ref model) => model.device.clone(),
728            Model::Starcoder2(ref model) => model.device.clone(),
729            Model::Qwen(ref model) => model.device.clone(),
730            Model::Qwen3(ref model) => model.device.clone(),
731            Model::Qwen3MoE(ref model) => model.device.clone(),
732            Model::Qwen35(ref model) => model.device.clone(),
733        }
734    }
735    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
736        Some(self.tokenizer.clone())
737    }
738    fn name(&self) -> String {
739        self.model_id.clone()
740    }
741    fn reset_non_granular_state(&self) {
742        if let Some(s) = self.non_granular_state.as_ref() {
743            *self.cache().full().get_scalings_cache() = None;
744            *get_mut_arcmutex!(s.non_granular_index) = 0;
745        }
746    }
747    fn get_metadata(&self) -> Arc<GeneralMetadata> {
748        self.metadata.clone()
749    }
750    fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
751        self.generation_defaults.clone()
752    }
753    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
754        Some(&*self.mapper)
755    }
756}
757
758#[async_trait::async_trait]
759impl Pipeline for GGUFPipeline {
760    fn forward_inputs(
761        &mut self,
762        inputs: Box<dyn Any>,
763        return_raw_logits: bool,
764    ) -> Result<ForwardInputsResult, hanzo_ml::Error> {
765        let ModelInputs {
766            input_ids,
767            input_ids_full,
768            seqlen_offsets,
769            seqlen_offsets_full,
770            context_lens,
771            position_ids: _, // NOTE(hanzoai): ignore, it is for phi3
772            paged_attn_meta,
773            flash_meta,
774            flash_meta_full,
775        } = *inputs.downcast().expect("Downcast failed.");
776        let metadata = self.get_metadata();
777        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
778            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
779            (Some(_), None) => {
780                // This can happen if Rust-side user code is wrong
781                hanzo_ml::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
782            }
783            (None, Some(_)) => {
784                // This should never happen but we handle it anyway
785                hanzo_ml::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
786            }
787            (None, None) => None,
788        };
789        let logits = match self.model {
790            Model::Llama(ref model) => {
791                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
792            }
793            Model::Phi2(ref model) => {
794                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
795            }
796            Model::XLoraLlama(ref model) => model.forward(
797                &input_ids,
798                input_ids_full.as_ref().unwrap_or(&input_ids),
799                &seqlen_offsets,
800                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
801                self.no_kv_cache,
802                &self.non_granular_state,
803                context_lens,
804                &flash_meta,
805                flash_meta_full.as_ref().unwrap_or(&flash_meta),
806            )?,
807            Model::Phi3(ref model) => {
808                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
809            }
810            Model::XLoraPhi3(ref model) => model.forward(
811                &input_ids,
812                input_ids_full.as_ref().unwrap_or(&input_ids),
813                &seqlen_offsets,
814                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
815                self.no_kv_cache,
816                &self.non_granular_state,
817                context_lens,
818                &flash_meta,
819                flash_meta_full.as_ref().unwrap_or(&flash_meta),
820            )?,
821            Model::Starcoder2(ref model) => {
822                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
823            }
824            Model::Qwen(ref model) => {
825                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
826            }
827            Model::Qwen3(ref model) => {
828                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
829            }
830            Model::Qwen3MoE(ref model) => {
831                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
832            }
833            Model::Qwen35(ref model) => {
834                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
835            }
836        };
837        if return_raw_logits {
838            Ok(ForwardInputsResult::RawLogits { logits })
839        } else {
840            Ok(ForwardInputsResult::CausalGeneration { logits })
841        }
842    }
843    async fn sample_causal_gen(
844        &self,
845        seqs: &mut [&mut Sequence],
846        logits: Vec<Tensor>,
847        prefix_cacher: &mut PrefixCacheManagerV2,
848        disable_eos_stop: bool,
849        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
850    ) -> Result<(), hanzo_ml::Error> {
851        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
852    }
853    fn category(&self) -> ModelCategory {
854        ModelCategory::Text
855    }
856}
857
858// TODO
859impl AnyMoePipelineMixin for GGUFPipeline {}