1use super::llg::build_llg_factory;
2use super::{
3 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
4 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
5 TokenSource,
6};
7use super::{
8 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
9 MetadataMixin, ModelCategory, PreProcessingMixin,
10};
11use crate::attention::ATTENTION_CHUNK_SIZE;
12use crate::device_map::{self, DeviceMapper};
13use crate::distributed::WorkerTransferData;
14use crate::gguf::{
15 get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
16};
17use crate::gguf::{Content, GGUFArchitecture};
18use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
19use crate::lora::Ordering;
20use crate::paged_attention::{
21 calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
22};
23use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
24use crate::pipeline::loaders::DeviceMappedModelLoader;
25use crate::pipeline::sampling::sample_and_add_toks;
26use crate::pipeline::ChatTemplate;
27use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
28use crate::prefix_cacher::PrefixCacheManagerV2;
29use crate::sequence::Sequence;
30use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
31use crate::utils::model_config as ModelConfig;
32use crate::utils::progress::ProgressScopeGuard;
33use crate::utils::tokenizer::get_tokenizer;
34use crate::xlora_models::NonGranularState;
35use crate::{
36 distributed, get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths,
37 PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
38};
39use crate::{
40 models::quantized_llama::ModelWeights as QLlama,
41 models::quantized_phi2::ModelWeights as QPhi,
42 models::quantized_phi3::ModelWeights as QPhi3,
43 models::quantized_qwen::ModelWeights as QQwen,
44 models::quantized_qwen3::ModelWeights as QQwen3,
45 models::quantized_qwen3_5_moe::ModelWeights as QQwen35,
46 models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
47 models::quantized_starcoder2::ModelWeights as QStarcoder2,
48 utils::tokens::get_token,
49 xlora_models::{XLoraQLlama, XLoraQPhi3},
50};
51use anyhow::{bail, Result};
52use either::Either;
53use hanzo_ml::{Device, Tensor};
54use hanzo_quant::IsqType;
55use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
56use rand_isaac::Isaac64Rng;
57use std::any::Any;
58use std::path::PathBuf;
59use std::str::FromStr;
60use std::sync::Arc;
61use std::{env, fs};
62use tokenizers::Tokenizer;
63use tokio::sync::Mutex;
64use tracing::{debug, info, warn};
65
66enum Model {
67 Llama(QLlama),
68 Phi2(QPhi),
69 XLoraLlama(XLoraQLlama),
70 XLoraPhi3(XLoraQPhi3),
71 Phi3(QPhi3),
72 Starcoder2(QStarcoder2),
73 Qwen(QQwen),
74 Qwen3(QQwen3),
75 Qwen3MoE(QQwen3MoE),
76 Qwen35(QQwen35),
77}
78
79pub struct GGUFPipeline {
80 model: Model,
81 tokenizer: Arc<Tokenizer>,
82 no_kv_cache: bool,
83 chat_template: Arc<ChatTemplate>,
84 model_id: String,
85 non_granular_state: Option<NonGranularState>,
86 metadata: Arc<GeneralMetadata>,
87 generation_defaults: Option<crate::ModelGenerationDefaults>,
88 mapper: Box<dyn DeviceMapper + Send + Sync>,
89}
90
91pub struct GGUFLoader {
93 model_id: Option<String>,
94 quantized_model_id: String,
95 quantized_filenames: Vec<String>,
96 xlora_model_id: Option<String>,
97 xlora_order: Option<Ordering>,
98 no_kv_cache: bool,
99 chat_template: Option<String>,
100 kind: ModelKind,
101 tgt_non_granular_index: Option<usize>,
102 config: GGUFSpecificConfig,
103 jinja_explicit: Option<String>,
104 lora_adapter_ids: Option<Vec<String>>,
105}
106
107#[derive(Clone, Default)]
108pub struct GGUFSpecificConfig {
110 pub topology: Option<Topology>,
111}
112
113#[derive(Default)]
114pub struct GGUFLoaderBuilder {
116 model_id: Option<String>,
117 quantized_model_id: String,
118 quantized_filenames: Vec<String>,
119 xlora_model_id: Option<String>,
120 kind: ModelKind,
121 xlora_order: Option<Ordering>,
122 no_kv_cache: bool,
123 chat_template: Option<String>,
124 tgt_non_granular_index: Option<usize>,
125 config: GGUFSpecificConfig,
126 jinja_explicit: Option<String>,
127}
128
129impl GGUFLoaderBuilder {
130 pub fn new(
134 chat_template: Option<String>,
135 tok_model_id: Option<String>,
136 quantized_model_id: String,
137 quantized_filenames: Vec<String>,
138 config: GGUFSpecificConfig,
139 no_kv_cache: bool,
140 jinja_explicit: Option<String>,
141 ) -> Self {
142 let kind = ModelKind::GgufQuantized {
143 quant: QuantizationKind::Gguf,
144 };
145
146 Self {
147 chat_template,
148 model_id: tok_model_id,
149 kind,
150 quantized_filenames,
151 quantized_model_id,
152 config,
153 jinja_explicit,
154 no_kv_cache,
155 ..Default::default()
156 }
157 }
158
159 fn with_adapter(
160 mut self,
161 xlora_model_id: String,
162 xlora_order: Ordering,
163 no_kv_cache: bool,
164 tgt_non_granular_index: Option<usize>,
165 ) -> Self {
166 self.xlora_model_id = Some(xlora_model_id);
167 self.xlora_order = Some(xlora_order);
168 self.no_kv_cache = no_kv_cache;
169 self.tgt_non_granular_index = tgt_non_granular_index;
170 self.model_id = if let Some(id) = self.model_id {
171 Some(id)
172 } else {
173 info!(
174 "Using adapter base model ID: `{}`",
175 self.xlora_order.as_ref().unwrap().base_model_id
176 );
177 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
178 };
179 self
180 }
181
182 pub fn with_xlora(
183 mut self,
184 xlora_model_id: String,
185 xlora_order: Ordering,
186 no_kv_cache: bool,
187 tgt_non_granular_index: Option<usize>,
188 ) -> Self {
189 self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();
190
191 self.with_adapter(
192 xlora_model_id,
193 xlora_order,
194 no_kv_cache,
195 tgt_non_granular_index,
196 )
197 }
198
199 pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
200 self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();
201
202 self.with_adapter(lora_model_id, lora_order, false, None)
203 }
204
205 pub fn build(self) -> Box<dyn Loader> {
206 Box::new(GGUFLoader {
207 model_id: self.model_id,
208 xlora_model_id: self.xlora_model_id,
209 kind: self.kind,
210 xlora_order: self.xlora_order,
211 no_kv_cache: self.no_kv_cache,
212 chat_template: self.chat_template,
213 tgt_non_granular_index: self.tgt_non_granular_index,
214 quantized_filenames: self.quantized_filenames,
215 quantized_model_id: self.quantized_model_id,
216 config: self.config,
217 jinja_explicit: self.jinja_explicit,
218 lora_adapter_ids: None,
219 })
220 }
221}
222
223impl GGUFLoader {
224 #[allow(clippy::too_many_arguments)]
225 pub fn new(
226 model_id: Option<String>,
227 quantized_model_id: String,
228 quantized_filenames: Vec<String>,
229 xlora_model_id: Option<String>,
230 kind: ModelKind,
231 xlora_order: Option<Ordering>,
232 no_kv_cache: bool,
233 chat_template: Option<String>,
234 tgt_non_granular_index: Option<usize>,
235 config: GGUFSpecificConfig,
236 jinja_explicit: Option<String>,
237 ) -> Self {
238 let model_id = if let Some(id) = model_id {
239 Some(id)
240 } else if let Some(xlora_order) = xlora_order.clone() {
241 info!(
242 "Using adapter base model ID: `{}`",
243 xlora_order.base_model_id
244 );
245 Some(xlora_order.base_model_id.clone())
246 } else {
247 None
248 };
249 Self {
250 model_id,
251 quantized_model_id,
252 quantized_filenames,
253 xlora_model_id,
254 xlora_order,
255 no_kv_cache,
256 chat_template,
257 kind,
258 tgt_non_granular_index,
259 config,
260 jinja_explicit,
261 lora_adapter_ids: None,
262 }
263 }
264}
265
266impl Loader for GGUFLoader {
267 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
268 fn load_model_from_hf(
269 &self,
270 revision: Option<String>,
271 token_source: TokenSource,
272 dtype: &dyn TryIntoDType,
273 device: &Device,
274 silent: bool,
275 mapper: DeviceMapSetting,
276 in_situ_quant: Option<IsqType>,
277 paged_attn_config: Option<PagedAttentionConfig>,
278 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
279 let _progress_guard = ProgressScopeGuard::new(silent);
280 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
281 LocalModelPaths,
282 &token_source,
283 revision,
284 self,
285 self.quantized_model_id.clone(),
286 self.quantized_filenames.clone(),
287 silent
288 );
289
290 self.load_model_from_path(
291 &paths?,
292 dtype,
293 device,
294 silent,
295 mapper,
296 in_situ_quant,
297 paged_attn_config,
298 )
299 }
300
301 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
302 fn load_model_from_path(
303 &self,
304 paths: &Box<dyn ModelPaths>,
305 dtype: &dyn TryIntoDType,
306 device: &Device,
307 silent: bool,
308 mut mapper: DeviceMapSetting,
309 in_situ_quant: Option<IsqType>,
310 mut paged_attn_config: Option<PagedAttentionConfig>,
311 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
312 let _progress_guard = ProgressScopeGuard::new(silent);
313 if in_situ_quant.is_some() {
314 anyhow::bail!(
315 "You are trying to in-situ quantize a GGUF model. This will not do anything."
316 );
317 }
318
319 debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
320
321 let mut readers = Vec::new();
322 for filename in paths.get_weight_filenames() {
323 readers.push(std::fs::File::open(filename)?);
324 }
325 let mut readers = readers.iter_mut().collect::<Vec<_>>();
326 let model = Content::from_readers(&mut readers)?;
327
328 if !silent {
329 model.print_metadata()?;
330 }
331
332 let arch = model.arch();
333
334 let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
336
337 let mut max_kv_tokens: Option<usize> = None;
338
339 if let DeviceMapSetting::Auto(params) = mapper.clone() {
340 let devices = device_map::get_all_similar_devices(device)?;
341 let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;
343
344 let model = GgufDeviceMapLoaderInner {
345 model: &model,
346 arch,
347 };
348
349 let layer_sizes_in_bytes =
350 model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1, None)?;
351 let non_mapped_size_in_bytes =
352 model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1, None)?;
353 let total_model_size_in_bytes =
354 layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;
355
356 let new = model.get_device_layers(
357 "this is a dummy config!",
358 num_layers,
359 layer_sizes_in_bytes,
360 non_mapped_size_in_bytes,
361 total_model_size_in_bytes,
362 &devices,
363 dtype,
364 ¶ms,
365 paged_attn_config.as_ref(),
366 )?;
367 max_kv_tokens = Some(params.max_seq_len() * params.max_batch_size());
368 mapper = DeviceMapSetting::Map(new);
369 }
370
371 #[cfg(feature = "cuda")]
372 if let Device::Cuda(dev) = &device {
373 unsafe { dev.disable_event_tracking() };
374 }
375
376 let use_nccl = hanzo_quant::distributed::use_nccl();
377 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
378 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
379 let WorkerTransferData::Init { id: _, worker_rank } = payload;
380 vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
381 } else if use_nccl {
382 vec![hanzo_ml::Device::new_cuda(0)?]
383 } else {
384 device_map::get_all_similar_devices(device)?
385 };
386
387 let pipeline_mapper = mapper.into_mapper(
388 num_layers,
389 device,
390 self.config.topology.as_ref(),
391 &available_devices,
392 )?;
393 let mapper = mapper.into_mapper(
394 num_layers,
395 device,
396 self.config.topology.as_ref(),
397 &available_devices,
398 )?;
399 let mut layer_devices = Vec::new();
400 for layer in 0..num_layers {
401 let device = mapper.device_for(layer, false).cloned();
402 layer_devices.push(device);
403 }
404
405 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
408 if mapping_uses_cpu {
409 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
410 paged_attn_config = None;
411 }
412
413 let GgufTokenizerConversion {
414 tokenizer,
415 bos,
416 eos,
417 unk,
418 } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
419 convert_gguf_to_hf_tokenizer(&model)?
420 } else {
421 GgufTokenizerConversion {
422 tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
423 bos: None,
424 eos: None,
425 unk: None,
426 }
427 };
428
429 let gguf_chat_template =
431 if paths.get_template_filename().is_none() && self.chat_template.is_none() {
432 get_gguf_chat_template(&model)?
433 } else {
434 None
435 };
436
437 let has_adapter = self.kind.is_adapted();
438 let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
439
440 let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
441 warn!("Adapter models do not currently support PagedAttention, running without");
442 None
443 } else {
444 paged_attn_config
445 };
446
447 let model_config_metadata: ContentConfig = (&model).into();
448 let internal_dtype = mapper.get_min_dtype(dtype)?;
449
450 let model_config = {
451 let quant = ModelConfig::ParamsGGUF(
453 model,
454 (device, mapper).into(),
455 if paged_attn_config.is_some() {
456 AttentionImplementation::PagedAttention
457 } else {
458 AttentionImplementation::Eager
459 },
460 internal_dtype,
461 );
462
463 let mut adapter = None;
465 if has_adapter {
466 adapter.replace(ModelConfig::Adapter::try_new(
467 paths, device, silent, is_xlora,
468 )?);
469 }
470
471 ModelConfig::ModelParams::new(quant, adapter)
472 };
473
474 let model = match self.kind {
476 ModelKind::GgufQuantized { .. } => match arch {
477 GGUFArchitecture::Llama | GGUFArchitecture::Mistral3 => {
478 Model::Llama(QLlama::try_from(model_config)?)
479 }
480 GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
481 GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
482 GGUFArchitecture::Starcoder2 => {
483 Model::Starcoder2(QStarcoder2::try_from(model_config)?)
484 }
485 GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
486 GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
487 GGUFArchitecture::Qwen3MoE => Model::Qwen3MoE(QQwen3MoE::try_from(model_config)?),
488 GGUFArchitecture::Qwen35 | GGUFArchitecture::Qwen35MoE => {
489 Model::Qwen35(QQwen35::try_from(model_config)?)
490 }
491 a => bail!("Unsupported architecture `{a:?}` for GGUF"),
492 },
493 ModelKind::GgufAdapter { adapter, .. } => match arch {
494 GGUFArchitecture::Llama | GGUFArchitecture::Mistral3 => {
495 Model::XLoraLlama(XLoraQLlama::try_from(model_config)?)
496 }
497 GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
498 a => bail!(
499 "Unsupported architecture `{a:?}` for GGUF {kind}",
500 kind = adapter.pretty_name()
501 ),
502 },
503 _ => unreachable!(),
504 };
505
506 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
507 let model_config: &dyn ModelConfigLike = &model_config_metadata;
508 let cache_config = calculate_cache_config(
509 paged_attn_config.mem_gpu,
510 paged_attn_config.block_size,
511 internal_dtype,
512 paged_attn_config.cache_type,
513 model_config,
514 device,
515 &layer_devices,
516 silent,
517 None,
518 max_kv_tokens,
519 )?;
520 let cache_engine = CacheEngine::new(
521 model_config,
522 &cache_config,
523 internal_dtype,
524 device,
525 layer_devices,
526 )?;
527 (Some(cache_config), Some(cache_engine))
528 } else {
529 (None, None)
530 };
531
532 let gen_conf: Option<GenerationConfig> = paths
533 .get_gen_conf_filename()
534 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
535 let chat_template_explicit = paths
536 .get_chat_template_explicit()
537 .as_ref()
538 .map(|x| x.to_string_lossy().to_string());
539 let mut chat_template = get_chat_template(
540 paths,
541 self.jinja_explicit.as_ref(),
542 chat_template_explicit.as_ref(),
543 self.chat_template.as_ref(),
544 gguf_chat_template,
545 );
546
547 let max_seq_len = match model {
548 Model::Llama(ref l) => l.max_seq_len,
549 Model::Phi2(ref p) => p.max_seq_len,
550 Model::XLoraLlama(ref xl) => xl.max_seq_len,
551 Model::Phi3(ref p) => p.max_seq_len,
552 Model::XLoraPhi3(ref p) => p.max_seq_len,
553 Model::Starcoder2(ref p) => p.max_seq_len,
554 Model::Qwen(ref p) => p.max_seq_len,
555 Model::Qwen3(ref p) => p.max_seq_len,
556 Model::Qwen3MoE(ref p) => p.max_seq_len,
557 Model::Qwen35(ref p) => p.max_seq_len,
558 };
559 let llg_factory = build_llg_factory(tokenizer.clone())?;
560 let num_hidden_layers = match model {
561 Model::Llama(ref model) => model.cache.normal().0.len(),
562 Model::Phi2(ref model) => model.cache.normal().0.len(),
563 Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
564 Model::Phi3(ref model) => model.cache.normal().0.len(),
565 Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
566 Model::Starcoder2(ref model) => model.cache.normal().0.len(),
567 Model::Qwen(ref model) => model.cache.normal().0.len(),
568 Model::Qwen3(ref model) => model.cache.normal().0.len(),
569 Model::Qwen3MoE(ref model) => model.cache.normal().0.len(),
570 Model::Qwen35(ref model) => model.cache.hybrid().num_layers(),
571 };
572
573 if chat_template.bos_token.is_none() {
574 if let Some(v) = bos {
575 chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
576 }
577 }
578 if chat_template.eos_token.is_none() {
579 if let Some(v) = eos {
580 chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
581 }
582 }
583 if chat_template.unk_token.is_none() {
584 if let Some(v) = unk {
585 chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(v)));
586 }
587 }
588
589 let generation_defaults = gen_conf
590 .as_ref()
591 .and_then(GenerationConfig::generation_defaults);
592 let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
593 Ok(Arc::new(Mutex::new(GGUFPipeline {
594 model,
595 tokenizer: tokenizer.into(),
596 no_kv_cache: self.no_kv_cache,
597 chat_template: Arc::new(chat_template),
598 model_id: self
599 .model_id
600 .clone()
601 .unwrap_or(self.quantized_model_id.clone()),
602 non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
603 NonGranularState {
604 non_granular_index: Arc::new(Mutex::new(0)),
605 tgt_non_granular_index,
606 }
607 }),
608 metadata: Arc::new(GeneralMetadata {
609 max_seq_len,
610 llg_factory: Some(llg_factory),
611 no_kv_cache: self.no_kv_cache,
612 no_prefix_cache: false,
613 num_hidden_layers,
614 eos_tok: eos,
615 kind: self.kind.clone(),
616 is_xlora,
617 activation_dtype: internal_dtype,
618 sliding_window: None,
619 cache_config,
620 cache_engine,
621 model_metadata: Some(Arc::new(model_config_metadata)),
622 modalities: Modalities {
623 input: vec![SupportedModality::Text],
624 output: vec![SupportedModality::Text],
625 },
626 }),
627 generation_defaults,
628 mapper: pipeline_mapper,
629 })))
630 }
631
632 fn get_id(&self) -> String {
633 self.xlora_model_id
634 .as_deref()
635 .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
636 .to_string()
637 }
638
639 fn get_kind(&self) -> ModelKind {
640 self.kind.clone()
641 }
642}
643
644impl PreProcessingMixin for GGUFPipeline {
645 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
646 Some(self.chat_template.clone())
647 }
648 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
649 None
650 }
651}
652
653impl IsqPipelineMixin for GGUFPipeline {
654 fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
655 anyhow::bail!(
656 "You are trying to in-situ requantize a GGML model. This will not do anything."
657 )
658 }
659}
660
661impl CacheManagerMixin for GGUFPipeline {
662 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
663 match self.cache() {
664 EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
665 EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
666 EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
667 }
668 }
669 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
670 match self.cache() {
671 EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
672 EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
673 EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
674 }
675 }
676 fn set_none_cache(
677 &self,
678 seqs: &mut [&mut Sequence],
679 reset_non_granular: bool,
680 modify_draft_cache: bool,
681 load_preallocated_cache: bool,
682 ) {
683 match self.cache() {
684 EitherCache::Full(_) => {
685 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
686 }
687 EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
688 self,
689 seqs,
690 modify_draft_cache,
691 load_preallocated_cache,
692 ),
693 EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
694 self,
695 seqs,
696 modify_draft_cache,
697 load_preallocated_cache,
698 ),
699 }
700 if reset_non_granular {
701 self.reset_non_granular_state()
702 }
703 }
704 fn cache(&self) -> &EitherCache {
705 match self.model {
706 Model::Llama(ref model) => &model.cache,
707 Model::Phi2(ref model) => &model.cache,
708 Model::XLoraLlama(ref model) => &model.cache,
709 Model::Phi3(ref model) => &model.cache,
710 Model::XLoraPhi3(ref model) => &model.cache,
711 Model::Starcoder2(ref model) => &model.cache,
712 Model::Qwen(ref model) => &model.cache,
713 Model::Qwen3(ref model) => &model.cache,
714 Model::Qwen3MoE(ref model) => &model.cache,
715 Model::Qwen35(ref model) => &model.cache,
716 }
717 }
718}
719
720impl MetadataMixin for GGUFPipeline {
721 fn device(&self) -> Device {
722 match self.model {
723 Model::Llama(ref model) => model.device.clone(),
724 Model::Phi2(ref model) => model.device.clone(),
725 Model::XLoraLlama(ref model) => model.device.clone(),
726 Model::Phi3(ref model) => model.device.clone(),
727 Model::XLoraPhi3(ref model) => model.device.clone(),
728 Model::Starcoder2(ref model) => model.device.clone(),
729 Model::Qwen(ref model) => model.device.clone(),
730 Model::Qwen3(ref model) => model.device.clone(),
731 Model::Qwen3MoE(ref model) => model.device.clone(),
732 Model::Qwen35(ref model) => model.device.clone(),
733 }
734 }
735 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
736 Some(self.tokenizer.clone())
737 }
738 fn name(&self) -> String {
739 self.model_id.clone()
740 }
741 fn reset_non_granular_state(&self) {
742 if let Some(s) = self.non_granular_state.as_ref() {
743 *self.cache().full().get_scalings_cache() = None;
744 *get_mut_arcmutex!(s.non_granular_index) = 0;
745 }
746 }
747 fn get_metadata(&self) -> Arc<GeneralMetadata> {
748 self.metadata.clone()
749 }
750 fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
751 self.generation_defaults.clone()
752 }
753 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
754 Some(&*self.mapper)
755 }
756}
757
758#[async_trait::async_trait]
759impl Pipeline for GGUFPipeline {
760 fn forward_inputs(
761 &mut self,
762 inputs: Box<dyn Any>,
763 return_raw_logits: bool,
764 ) -> Result<ForwardInputsResult, hanzo_ml::Error> {
765 let ModelInputs {
766 input_ids,
767 input_ids_full,
768 seqlen_offsets,
769 seqlen_offsets_full,
770 context_lens,
771 position_ids: _, paged_attn_meta,
773 flash_meta,
774 flash_meta_full,
775 } = *inputs.downcast().expect("Downcast failed.");
776 let metadata = self.get_metadata();
777 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
778 (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
779 (Some(_), None) => {
780 hanzo_ml::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
782 }
783 (None, Some(_)) => {
784 hanzo_ml::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
786 }
787 (None, None) => None,
788 };
789 let logits = match self.model {
790 Model::Llama(ref model) => {
791 model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
792 }
793 Model::Phi2(ref model) => {
794 model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
795 }
796 Model::XLoraLlama(ref model) => model.forward(
797 &input_ids,
798 input_ids_full.as_ref().unwrap_or(&input_ids),
799 &seqlen_offsets,
800 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
801 self.no_kv_cache,
802 &self.non_granular_state,
803 context_lens,
804 &flash_meta,
805 flash_meta_full.as_ref().unwrap_or(&flash_meta),
806 )?,
807 Model::Phi3(ref model) => {
808 model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
809 }
810 Model::XLoraPhi3(ref model) => model.forward(
811 &input_ids,
812 input_ids_full.as_ref().unwrap_or(&input_ids),
813 &seqlen_offsets,
814 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
815 self.no_kv_cache,
816 &self.non_granular_state,
817 context_lens,
818 &flash_meta,
819 flash_meta_full.as_ref().unwrap_or(&flash_meta),
820 )?,
821 Model::Starcoder2(ref model) => {
822 model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
823 }
824 Model::Qwen(ref model) => {
825 model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
826 }
827 Model::Qwen3(ref model) => {
828 model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
829 }
830 Model::Qwen3MoE(ref model) => {
831 model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
832 }
833 Model::Qwen35(ref model) => {
834 model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
835 }
836 };
837 if return_raw_logits {
838 Ok(ForwardInputsResult::RawLogits { logits })
839 } else {
840 Ok(ForwardInputsResult::CausalGeneration { logits })
841 }
842 }
843 async fn sample_causal_gen(
844 &self,
845 seqs: &mut [&mut Sequence],
846 logits: Vec<Tensor>,
847 prefix_cacher: &mut PrefixCacheManagerV2,
848 disable_eos_stop: bool,
849 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
850 ) -> Result<(), hanzo_ml::Error> {
851 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
852 }
853 fn category(&self) -> ModelCategory {
854 ModelCategory::Text
855 }
856}
857
858impl AnyMoePipelineMixin for GGUFPipeline {}