1use super::llg::build_llg_factory;
2use super::{
3 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
4 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, QuantizationKind, TokenSource,
5};
6use super::{
7 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
8 MetadataMixin, ModelCategory, PreProcessingMixin,
9};
10use crate::attention::ATTENTION_CHUNK_SIZE;
11use crate::device_map::DeviceMapper;
12use crate::kv_cache::FullCacheManager;
13use crate::lora::Ordering;
14use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
15use crate::pipeline::sampling::sample_and_add_toks;
16use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
17use crate::pipeline::{ChatTemplate, LocalModelPaths};
18use crate::prefix_cacher::PrefixCacheManagerV2;
19use crate::sequence::Sequence;
20use crate::utils::debug::DeviceRepr;
21use crate::utils::model_config as ModelConfig;
22use crate::utils::progress::ProgressScopeGuard;
23use crate::utils::tokenizer::get_tokenizer;
24use crate::xlora_models::NonGranularState;
25use crate::{
26 get_mut_arcmutex, get_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
27 TryIntoDType, DEBUG,
28};
29use crate::{
30 models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
31 xlora_models::XLoraQLlama,
32};
33use anyhow::Result;
34use hanzo_ml::quantized::ggml_file;
35use hanzo_ml::{Device, Tensor};
36use hanzo_quant::IsqType;
37use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
38use rand_isaac::Isaac64Rng;
39use std::any::Any;
40use std::fs;
41use std::path::PathBuf;
42use std::str::FromStr;
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tokio::sync::Mutex;
46use tracing::{debug, info, trace, warn};
47
48enum Model {
49 Llama(Box<QLlama>),
50 XLoraLlama(Box<XLoraQLlama>),
51}
52
53pub struct GGMLPipeline {
54 model: Model,
55 tokenizer: Arc<Tokenizer>,
56 no_kv_cache: bool,
57 chat_template: Arc<ChatTemplate>,
58 model_id: String,
59 non_granular_state: Option<NonGranularState>,
60 metadata: Arc<GeneralMetadata>,
61 generation_defaults: Option<crate::ModelGenerationDefaults>,
62}
63
64pub struct GGMLLoader {
66 model_id: String,
67 config: GGMLSpecificConfig,
68 quantized_model_id: Option<String>,
69 quantized_filename: Option<String>,
70 xlora_model_id: Option<String>,
71 xlora_order: Option<Ordering>,
72 no_kv_cache: bool,
73 chat_template: Option<String>,
74 tokenizer_json: Option<String>,
75 kind: ModelKind,
76 tgt_non_granular_index: Option<usize>,
77 jinja_explicit: Option<String>,
78 lora_adapter_ids: Option<Vec<String>>,
79}
80
81#[derive(Clone, Default)]
82pub struct GGMLSpecificConfig {
84 pub gqa: usize,
85 pub topology: Option<Topology>,
86}
87
88#[derive(Default)]
89pub struct GGMLLoaderBuilder {
91 model_id: Option<String>,
92 config: GGMLSpecificConfig,
93 quantized_model_id: String,
94 quantized_filename: String,
95 xlora_model_id: Option<String>,
96 kind: ModelKind,
97 xlora_order: Option<Ordering>,
98 no_kv_cache: bool,
99 chat_template: Option<String>,
100 tokenizer_json: Option<String>,
101 tgt_non_granular_index: Option<usize>,
102 jinja_explicit: Option<String>,
103}
104
105impl GGMLLoaderBuilder {
106 #[allow(clippy::too_many_arguments)]
107 pub fn new(
108 config: GGMLSpecificConfig,
109 chat_template: Option<String>,
110 tokenizer_json: Option<String>,
111 model_id: Option<String>,
112 quantized_model_id: String,
113 quantized_filename: String,
114 no_kv_cache: bool,
115 jinja_explicit: Option<String>,
116 ) -> Self {
117 let kind = ModelKind::GgufQuantized {
118 quant: QuantizationKind::Ggml,
119 };
120
121 Self {
122 config,
123 chat_template,
124 tokenizer_json,
125 model_id,
126 kind,
127 quantized_filename,
128 quantized_model_id,
129 no_kv_cache,
130 jinja_explicit,
131 ..Default::default()
132 }
133 }
134
135 fn with_adapter(
136 mut self,
137 xlora_model_id: String,
138 xlora_order: Ordering,
139 no_kv_cache: bool,
140 tgt_non_granular_index: Option<usize>,
141 ) -> Self {
142 self.xlora_model_id = Some(xlora_model_id);
143 self.xlora_order = Some(xlora_order);
144 self.no_kv_cache = no_kv_cache;
145 self.tgt_non_granular_index = tgt_non_granular_index;
146 self.model_id = if let Some(id) = self.model_id {
147 Some(id)
148 } else {
149 info!(
150 "Using adapter base model ID: `{}`",
151 self.xlora_order.as_ref().unwrap().base_model_id
152 );
153 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
154 };
155 self
156 }
157
158 pub fn with_xlora(
159 mut self,
160 xlora_model_id: String,
161 xlora_order: Ordering,
162 no_kv_cache: bool,
163 tgt_non_granular_index: Option<usize>,
164 ) -> Self {
165 self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into();
166
167 self.with_adapter(
168 xlora_model_id,
169 xlora_order,
170 no_kv_cache,
171 tgt_non_granular_index,
172 )
173 }
174
175 pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
176 self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into();
177
178 self.with_adapter(lora_model_id, lora_order, false, None)
179 }
180
181 pub fn build(self) -> Box<dyn Loader> {
182 Box::new(GGMLLoader {
183 model_id: self.model_id.unwrap(),
184 config: self.config,
185 xlora_model_id: self.xlora_model_id,
186 kind: self.kind,
187 xlora_order: self.xlora_order,
188 no_kv_cache: self.no_kv_cache,
189 chat_template: self.chat_template,
190 tokenizer_json: self.tokenizer_json,
191 tgt_non_granular_index: self.tgt_non_granular_index,
192 quantized_filename: Some(self.quantized_filename),
193 quantized_model_id: Some(self.quantized_model_id),
194 jinja_explicit: self.jinja_explicit,
195 lora_adapter_ids: None,
196 })
197 }
198}
199
200impl GGMLLoader {
201 #[allow(clippy::too_many_arguments)]
202 pub fn new(
203 model_id: Option<String>,
204 config: GGMLSpecificConfig,
205 quantized_model_id: Option<String>,
206 quantized_filename: Option<String>,
207 xlora_model_id: Option<String>,
208 kind: ModelKind,
209 xlora_order: Option<Ordering>,
210 no_kv_cache: bool,
211 chat_template: Option<String>,
212 tokenizer_json: Option<String>,
213 tgt_non_granular_index: Option<usize>,
214 jinja_explicit: Option<String>,
215 ) -> Self {
216 let model_id = if let Some(id) = model_id {
217 id
218 } else {
219 info!(
220 "Using adapter base model ID: `{}`",
221 xlora_order.as_ref().unwrap().base_model_id
222 );
223 xlora_order.as_ref().unwrap().base_model_id.clone()
224 };
225 Self {
226 model_id,
227 config,
228 quantized_model_id,
229 quantized_filename,
230 xlora_model_id,
231 xlora_order,
232 no_kv_cache,
233 chat_template,
234 tokenizer_json,
235 kind,
236 tgt_non_granular_index,
237 jinja_explicit,
238 lora_adapter_ids: None,
239 }
240 }
241}
242
243impl Loader for GGMLLoader {
244 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
245 fn load_model_from_path(
246 &self,
247 paths: &Box<dyn ModelPaths>,
248 dtype: &dyn TryIntoDType,
249 device: &Device,
250 silent: bool,
251 mapper: DeviceMapSetting,
252 in_situ_quant: Option<IsqType>,
253 mut paged_attn_config: Option<PagedAttentionConfig>,
254 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
255 let _progress_guard = ProgressScopeGuard::new(silent);
256 if in_situ_quant.is_some() {
257 anyhow::bail!(
258 "You are trying to in-situ quantize a GGML model. This will not do anything."
259 );
260 }
261
262 if matches!(mapper, DeviceMapSetting::Map(_)) {
263 anyhow::bail!("Device mapping is not supported for diffusion models.")
264 }
265
266 if paged_attn_config.is_some() {
267 warn!("PagedAttention is not supported for GGML models, disabling it.");
268
269 paged_attn_config = None;
270 }
271
272 debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
273
274 info!(
275 "Loading model `{}` on {}.",
276 self.get_id(),
277 device.device_pretty_repr()
278 );
279
280 #[cfg(feature = "cuda")]
281 if let Device::Cuda(dev) = &device {
282 unsafe { dev.disable_event_tracking() };
283 }
284
285 let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
286 let model = ggml_file::Content::read(&mut file, device)
287 .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;
288
289 trace!("Model config: {:?}", model.hparams);
290
291 if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
292 let mut tensors = Vec::new();
293 for (name, t) in &model.tensors {
294 tensors.push(format!(
295 "name = `{name}`, shape = {:?}, dtype = {:?}",
296 t.shape().clone(),
297 t.dtype(),
298 ));
299 }
300 fs::write(
301 "hanzo_ggml_tensors.txt",
302 serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
303 )?;
304
305 info!("Debug is enabled, wrote the names and information about each tensor to `hanzo_ggml_tensors.txt`.");
306 }
307
308 let _ = if paged_attn_config.is_none() {
309 warn!("GGML does not currently support PagedAttention, running without");
310 None
311 } else {
312 paged_attn_config
313 };
314
315 let has_adapter = self.kind.is_adapted();
316 let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
317 let internal_dtype = dtype.try_into_dtype(&[device]).unwrap();
318
319 let model_config = {
320 let quant = ModelConfig::ParamsGGML((model, self.config.gqa, internal_dtype).into());
322
323 let mut adapter = None;
325 if has_adapter {
326 adapter.replace(ModelConfig::Adapter::try_new(
327 paths, device, silent, is_xlora,
328 )?);
329 }
330
331 ModelConfig::ModelParams::new(quant, adapter)
332 };
333
334 let model = match self.kind {
337 ModelKind::GgufQuantized { .. } => {
338 Model::Llama(Box::new(QLlama::try_from(model_config)?))
339 }
340 ModelKind::GgufAdapter { .. } => {
341 Model::XLoraLlama(Box::new(XLoraQLlama::try_from(model_config)?))
342 }
343 _ => unreachable!(),
344 };
345
346 let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
347 let gen_conf: Option<GenerationConfig> = paths
348 .get_gen_conf_filename()
349 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
350 let chat_template_explicit = paths
351 .get_chat_template_explicit()
352 .as_ref()
353 .map(|x| x.to_string_lossy().to_string());
354 let chat_template = get_chat_template(
355 paths,
356 self.jinja_explicit.as_ref(),
357 chat_template_explicit.as_ref(),
358 self.chat_template.as_ref(),
359 None,
360 );
361
362 let max_seq_len = match model {
363 Model::Llama(ref l) => l.max_seq_len,
364 Model::XLoraLlama(ref xl) => xl.max_seq_len,
365 };
366 let llg_factory = build_llg_factory(tokenizer.clone())?;
367 let num_hidden_layers = match model {
368 Model::Llama(ref model) => model.cache.normal().0.len(),
369 Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
370 };
371 let generation_defaults = gen_conf
372 .as_ref()
373 .and_then(GenerationConfig::generation_defaults);
374 let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
375 Ok(Arc::new(Mutex::new(GGMLPipeline {
376 model,
377 tokenizer: tokenizer.into(),
378 no_kv_cache: self.no_kv_cache,
379 chat_template: Arc::new(chat_template),
380 model_id: self.model_id.clone(),
381 non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
382 NonGranularState {
383 non_granular_index: Arc::new(Mutex::new(0)),
384 tgt_non_granular_index,
385 }
386 }),
387 metadata: Arc::new(GeneralMetadata {
388 max_seq_len,
389 llg_factory: Some(llg_factory),
390 no_kv_cache: self.no_kv_cache,
391 no_prefix_cache: false,
392 num_hidden_layers,
393 eos_tok: eos,
394 kind: self.kind.clone(),
395 is_xlora,
396 activation_dtype: internal_dtype,
397 sliding_window: None,
398 cache_config: None,
399 cache_engine: None,
400 model_metadata: None,
401 modalities: Modalities {
402 input: vec![SupportedModality::Text],
403 output: vec![SupportedModality::Text],
404 },
405 }),
406 generation_defaults,
407 })))
408 }
409
410 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
411 fn load_model_from_hf(
412 &self,
413 revision: Option<String>,
414 token_source: TokenSource,
415 dtype: &dyn TryIntoDType,
416 device: &Device,
417 silent: bool,
418 mapper: DeviceMapSetting,
419 in_situ_quant: Option<IsqType>,
420 paged_attn_config: Option<PagedAttentionConfig>,
421 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
422 let _progress_guard = ProgressScopeGuard::new(silent);
423 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
424 LocalModelPaths,
425 &token_source,
426 revision,
427 self,
428 self.quantized_model_id,
429 Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
430 silent,
431 false );
433 self.load_model_from_path(
434 &paths?,
435 dtype,
436 device,
437 silent,
438 mapper,
439 in_situ_quant,
440 paged_attn_config,
441 )
442 }
443
444 fn get_id(&self) -> String {
445 self.xlora_model_id
446 .as_deref()
447 .unwrap_or(&self.model_id)
448 .to_string()
449 }
450
451 fn get_kind(&self) -> ModelKind {
452 self.kind.clone()
453 }
454}
455
456impl PreProcessingMixin for GGMLPipeline {
457 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
458 Some(self.chat_template.clone())
459 }
460 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
461 None
462 }
463}
464
465impl IsqPipelineMixin for GGMLPipeline {
466 fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
467 anyhow::bail!(
468 "You are trying to in-situ requantize a GGML model. This will not do anything."
469 )
470 }
471}
472
473impl CacheManagerMixin for GGMLPipeline {
474 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
475 FullCacheManager.clone_in_cache(self, seqs, false)
476 }
477 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
478 FullCacheManager.clone_out_cache(self, seqs, false)
479 }
480 fn set_none_cache(
481 &self,
482 seqs: &mut [&mut Sequence],
483 reset_non_granular: bool,
484 modify_draft_cache: bool,
485
486 load_preallocated_cache: bool,
487 ) {
488 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
489 if reset_non_granular {
490 self.reset_non_granular_state()
491 }
492 }
493 fn cache(&self) -> &EitherCache {
494 match self.model {
495 Model::Llama(ref model) => &model.cache,
496 Model::XLoraLlama(ref model) => &model.cache,
497 }
498 }
499}
500
501impl MetadataMixin for GGMLPipeline {
502 fn device(&self) -> Device {
503 match self.model {
504 Model::Llama(ref model) => model.device.clone(),
505 Model::XLoraLlama(ref model) => model.device.clone(),
506 }
507 }
508 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
509 Some(self.tokenizer.clone())
510 }
511 fn name(&self) -> String {
512 self.model_id.clone()
513 }
514 fn reset_non_granular_state(&self) {
515 if let Some(s) = self.non_granular_state.as_ref() {
516 *self.cache().full().get_scalings_cache() = None;
517 *get_mut_arcmutex!(s.non_granular_index) = 0;
518 }
519 }
520 fn get_metadata(&self) -> Arc<GeneralMetadata> {
521 self.metadata.clone()
522 }
523 fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
524 self.generation_defaults.clone()
525 }
526 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
527 None
528 }
529}
530
531#[async_trait::async_trait]
532impl Pipeline for GGMLPipeline {
533 fn forward_inputs(
534 &mut self,
535 inputs: Box<dyn Any>,
536 return_raw_logits: bool,
537 ) -> Result<ForwardInputsResult, hanzo_ml::Error> {
538 let ModelInputs {
539 input_ids,
540 input_ids_full,
541 seqlen_offsets,
542 seqlen_offsets_full,
543 context_lens,
544 position_ids: _, paged_attn_meta: _, flash_meta, flash_meta_full, } = *inputs.downcast().expect("Downcast failed.");
549 let logits = match self.model {
550 Model::Llama(ref model) => {
551 model.forward(&input_ids, &seqlen_offsets, context_lens, None)?
552 }
553 Model::XLoraLlama(ref model) => model.forward(
554 &input_ids,
555 input_ids_full.as_ref().unwrap_or(&input_ids),
556 &seqlen_offsets,
557 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
558 self.no_kv_cache,
559 &self.non_granular_state,
560 context_lens,
561 &flash_meta,
562 flash_meta_full.as_ref().unwrap_or(&flash_meta),
563 )?,
564 };
565 if return_raw_logits {
566 Ok(ForwardInputsResult::RawLogits { logits })
567 } else {
568 Ok(ForwardInputsResult::CausalGeneration { logits })
569 }
570 }
571 async fn sample_causal_gen(
572 &self,
573 seqs: &mut [&mut Sequence],
574 logits: Vec<Tensor>,
575 prefix_cacher: &mut PrefixCacheManagerV2,
576 disable_eos_stop: bool,
577 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
578 ) -> Result<(), hanzo_ml::Error> {
579 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
580 }
581 fn category(&self) -> ModelCategory {
582 ModelCategory::Text
583 }
584}
585
586impl AnyMoePipelineMixin for GGMLPipeline {}