Skip to main content

hanzo_engine/pipeline/
ggml.rs

1use super::llg::build_llg_factory;
2use super::{
3    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
4    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, QuantizationKind, TokenSource,
5};
6use super::{
7    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
8    MetadataMixin, ModelCategory, PreProcessingMixin,
9};
10use crate::attention::ATTENTION_CHUNK_SIZE;
11use crate::device_map::DeviceMapper;
12use crate::kv_cache::FullCacheManager;
13use crate::lora::Ordering;
14use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
15use crate::pipeline::sampling::sample_and_add_toks;
16use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
17use crate::pipeline::{ChatTemplate, LocalModelPaths};
18use crate::prefix_cacher::PrefixCacheManagerV2;
19use crate::sequence::Sequence;
20use crate::utils::debug::DeviceRepr;
21use crate::utils::model_config as ModelConfig;
22use crate::utils::progress::ProgressScopeGuard;
23use crate::utils::tokenizer::get_tokenizer;
24use crate::xlora_models::NonGranularState;
25use crate::{
26    get_mut_arcmutex, get_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
27    TryIntoDType, DEBUG,
28};
29use crate::{
30    models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
31    xlora_models::XLoraQLlama,
32};
33use anyhow::Result;
34use hanzo_ml::quantized::ggml_file;
35use hanzo_ml::{Device, Tensor};
36use hanzo_quant::IsqType;
37use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
38use rand_isaac::Isaac64Rng;
39use std::any::Any;
40use std::fs;
41use std::path::PathBuf;
42use std::str::FromStr;
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tokio::sync::Mutex;
46use tracing::{debug, info, trace, warn};
47
48enum Model {
49    Llama(Box<QLlama>),
50    XLoraLlama(Box<XLoraQLlama>),
51}
52
53pub struct GGMLPipeline {
54    model: Model,
55    tokenizer: Arc<Tokenizer>,
56    no_kv_cache: bool,
57    chat_template: Arc<ChatTemplate>,
58    model_id: String,
59    non_granular_state: Option<NonGranularState>,
60    metadata: Arc<GeneralMetadata>,
61    generation_defaults: Option<crate::ModelGenerationDefaults>,
62}
63
64/// A loader for a GGML model.
65pub struct GGMLLoader {
66    model_id: String,
67    config: GGMLSpecificConfig,
68    quantized_model_id: Option<String>,
69    quantized_filename: Option<String>,
70    xlora_model_id: Option<String>,
71    xlora_order: Option<Ordering>,
72    no_kv_cache: bool,
73    chat_template: Option<String>,
74    tokenizer_json: Option<String>,
75    kind: ModelKind,
76    tgt_non_granular_index: Option<usize>,
77    jinja_explicit: Option<String>,
78    lora_adapter_ids: Option<Vec<String>>,
79}
80
81#[derive(Clone, Default)]
82/// Config for a GGML loader.
83pub struct GGMLSpecificConfig {
84    pub gqa: usize,
85    pub topology: Option<Topology>,
86}
87
88#[derive(Default)]
89/// A builder for a GGML loader.
90pub struct GGMLLoaderBuilder {
91    model_id: Option<String>,
92    config: GGMLSpecificConfig,
93    quantized_model_id: String,
94    quantized_filename: String,
95    xlora_model_id: Option<String>,
96    kind: ModelKind,
97    xlora_order: Option<Ordering>,
98    no_kv_cache: bool,
99    chat_template: Option<String>,
100    tokenizer_json: Option<String>,
101    tgt_non_granular_index: Option<usize>,
102    jinja_explicit: Option<String>,
103}
104
105impl GGMLLoaderBuilder {
106    #[allow(clippy::too_many_arguments)]
107    pub fn new(
108        config: GGMLSpecificConfig,
109        chat_template: Option<String>,
110        tokenizer_json: Option<String>,
111        model_id: Option<String>,
112        quantized_model_id: String,
113        quantized_filename: String,
114        no_kv_cache: bool,
115        jinja_explicit: Option<String>,
116    ) -> Self {
117        let kind = ModelKind::GgufQuantized {
118            quant: QuantizationKind::Ggml,
119        };
120
121        Self {
122            config,
123            chat_template,
124            tokenizer_json,
125            model_id,
126            kind,
127            quantized_filename,
128            quantized_model_id,
129            no_kv_cache,
130            jinja_explicit,
131            ..Default::default()
132        }
133    }
134
135    fn with_adapter(
136        mut self,
137        xlora_model_id: String,
138        xlora_order: Ordering,
139        no_kv_cache: bool,
140        tgt_non_granular_index: Option<usize>,
141    ) -> Self {
142        self.xlora_model_id = Some(xlora_model_id);
143        self.xlora_order = Some(xlora_order);
144        self.no_kv_cache = no_kv_cache;
145        self.tgt_non_granular_index = tgt_non_granular_index;
146        self.model_id = if let Some(id) = self.model_id {
147            Some(id)
148        } else {
149            info!(
150                "Using adapter base model ID: `{}`",
151                self.xlora_order.as_ref().unwrap().base_model_id
152            );
153            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
154        };
155        self
156    }
157
158    pub fn with_xlora(
159        mut self,
160        xlora_model_id: String,
161        xlora_order: Ordering,
162        no_kv_cache: bool,
163        tgt_non_granular_index: Option<usize>,
164    ) -> Self {
165        self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into();
166
167        self.with_adapter(
168            xlora_model_id,
169            xlora_order,
170            no_kv_cache,
171            tgt_non_granular_index,
172        )
173    }
174
175    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
176        self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into();
177
178        self.with_adapter(lora_model_id, lora_order, false, None)
179    }
180
181    pub fn build(self) -> Box<dyn Loader> {
182        Box::new(GGMLLoader {
183            model_id: self.model_id.unwrap(),
184            config: self.config,
185            xlora_model_id: self.xlora_model_id,
186            kind: self.kind,
187            xlora_order: self.xlora_order,
188            no_kv_cache: self.no_kv_cache,
189            chat_template: self.chat_template,
190            tokenizer_json: self.tokenizer_json,
191            tgt_non_granular_index: self.tgt_non_granular_index,
192            quantized_filename: Some(self.quantized_filename),
193            quantized_model_id: Some(self.quantized_model_id),
194            jinja_explicit: self.jinja_explicit,
195            lora_adapter_ids: None,
196        })
197    }
198}
199
200impl GGMLLoader {
201    #[allow(clippy::too_many_arguments)]
202    pub fn new(
203        model_id: Option<String>,
204        config: GGMLSpecificConfig,
205        quantized_model_id: Option<String>,
206        quantized_filename: Option<String>,
207        xlora_model_id: Option<String>,
208        kind: ModelKind,
209        xlora_order: Option<Ordering>,
210        no_kv_cache: bool,
211        chat_template: Option<String>,
212        tokenizer_json: Option<String>,
213        tgt_non_granular_index: Option<usize>,
214        jinja_explicit: Option<String>,
215    ) -> Self {
216        let model_id = if let Some(id) = model_id {
217            id
218        } else {
219            info!(
220                "Using adapter base model ID: `{}`",
221                xlora_order.as_ref().unwrap().base_model_id
222            );
223            xlora_order.as_ref().unwrap().base_model_id.clone()
224        };
225        Self {
226            model_id,
227            config,
228            quantized_model_id,
229            quantized_filename,
230            xlora_model_id,
231            xlora_order,
232            no_kv_cache,
233            chat_template,
234            tokenizer_json,
235            kind,
236            tgt_non_granular_index,
237            jinja_explicit,
238            lora_adapter_ids: None,
239        }
240    }
241}
242
243impl Loader for GGMLLoader {
244    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
245    fn load_model_from_path(
246        &self,
247        paths: &Box<dyn ModelPaths>,
248        dtype: &dyn TryIntoDType,
249        device: &Device,
250        silent: bool,
251        mapper: DeviceMapSetting,
252        in_situ_quant: Option<IsqType>,
253        mut paged_attn_config: Option<PagedAttentionConfig>,
254    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
255        let _progress_guard = ProgressScopeGuard::new(silent);
256        if in_situ_quant.is_some() {
257            anyhow::bail!(
258                "You are trying to in-situ quantize a GGML model. This will not do anything."
259            );
260        }
261
262        if matches!(mapper, DeviceMapSetting::Map(_)) {
263            anyhow::bail!("Device mapping is not supported for diffusion models.")
264        }
265
266        if paged_attn_config.is_some() {
267            warn!("PagedAttention is not supported for GGML models, disabling it.");
268
269            paged_attn_config = None;
270        }
271
272        debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
273
274        info!(
275            "Loading model `{}` on {}.",
276            self.get_id(),
277            device.device_pretty_repr()
278        );
279
280        #[cfg(feature = "cuda")]
281        if let Device::Cuda(dev) = &device {
282            unsafe { dev.disable_event_tracking() };
283        }
284
285        let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
286        let model = ggml_file::Content::read(&mut file, device)
287            .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;
288
289        trace!("Model config: {:?}", model.hparams);
290
291        if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
292            let mut tensors = Vec::new();
293            for (name, t) in &model.tensors {
294                tensors.push(format!(
295                    "name = `{name}`, shape = {:?}, dtype = {:?}",
296                    t.shape().clone(),
297                    t.dtype(),
298                ));
299            }
300            fs::write(
301                "hanzo_ggml_tensors.txt",
302                serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
303            )?;
304
305            info!("Debug is enabled, wrote the names and information about each tensor to `hanzo_ggml_tensors.txt`.");
306        }
307
308        let _ = if paged_attn_config.is_none() {
309            warn!("GGML does not currently support PagedAttention, running without");
310            None
311        } else {
312            paged_attn_config
313        };
314
315        let has_adapter = self.kind.is_adapted();
316        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
317        let internal_dtype = dtype.try_into_dtype(&[device]).unwrap();
318
319        let model_config = {
320            // Base config (quantization only):
321            let quant = ModelConfig::ParamsGGML((model, self.config.gqa, internal_dtype).into());
322
323            // With optional adapter config:
324            let mut adapter = None;
325            if has_adapter {
326                adapter.replace(ModelConfig::Adapter::try_new(
327                    paths, device, silent, is_xlora,
328                )?);
329            }
330
331            ModelConfig::ModelParams::new(quant, adapter)
332        };
333
334        // Config into model:
335        // NOTE: No architecture to infer like GGUF, Llama model is implicitly matched
336        let model = match self.kind {
337            ModelKind::GgufQuantized { .. } => {
338                Model::Llama(Box::new(QLlama::try_from(model_config)?))
339            }
340            ModelKind::GgufAdapter { .. } => {
341                Model::XLoraLlama(Box::new(XLoraQLlama::try_from(model_config)?))
342            }
343            _ => unreachable!(),
344        };
345
346        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
347        let gen_conf: Option<GenerationConfig> = paths
348            .get_gen_conf_filename()
349            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
350        let chat_template_explicit = paths
351            .get_chat_template_explicit()
352            .as_ref()
353            .map(|x| x.to_string_lossy().to_string());
354        let chat_template = get_chat_template(
355            paths,
356            self.jinja_explicit.as_ref(),
357            chat_template_explicit.as_ref(),
358            self.chat_template.as_ref(),
359            None,
360        );
361
362        let max_seq_len = match model {
363            Model::Llama(ref l) => l.max_seq_len,
364            Model::XLoraLlama(ref xl) => xl.max_seq_len,
365        };
366        let llg_factory = build_llg_factory(tokenizer.clone())?;
367        let num_hidden_layers = match model {
368            Model::Llama(ref model) => model.cache.normal().0.len(),
369            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
370        };
371        let generation_defaults = gen_conf
372            .as_ref()
373            .and_then(GenerationConfig::generation_defaults);
374        let eos = calculate_eos_tokens(&chat_template, gen_conf.as_ref(), &tokenizer);
375        Ok(Arc::new(Mutex::new(GGMLPipeline {
376            model,
377            tokenizer: tokenizer.into(),
378            no_kv_cache: self.no_kv_cache,
379            chat_template: Arc::new(chat_template),
380            model_id: self.model_id.clone(),
381            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
382                NonGranularState {
383                    non_granular_index: Arc::new(Mutex::new(0)),
384                    tgt_non_granular_index,
385                }
386            }),
387            metadata: Arc::new(GeneralMetadata {
388                max_seq_len,
389                llg_factory: Some(llg_factory),
390                no_kv_cache: self.no_kv_cache,
391                no_prefix_cache: false,
392                num_hidden_layers,
393                eos_tok: eos,
394                kind: self.kind.clone(),
395                is_xlora,
396                activation_dtype: internal_dtype,
397                sliding_window: None,
398                cache_config: None,
399                cache_engine: None,
400                model_metadata: None,
401                modalities: Modalities {
402                    input: vec![SupportedModality::Text],
403                    output: vec![SupportedModality::Text],
404                },
405            }),
406            generation_defaults,
407        })))
408    }
409
410    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
411    fn load_model_from_hf(
412        &self,
413        revision: Option<String>,
414        token_source: TokenSource,
415        dtype: &dyn TryIntoDType,
416        device: &Device,
417        silent: bool,
418        mapper: DeviceMapSetting,
419        in_situ_quant: Option<IsqType>,
420        paged_attn_config: Option<PagedAttentionConfig>,
421    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
422        let _progress_guard = ProgressScopeGuard::new(silent);
423        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
424            LocalModelPaths,
425            &token_source,
426            revision,
427            self,
428            self.quantized_model_id,
429            Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
430            silent,
431            false // Never loading UQFF
432        );
433        self.load_model_from_path(
434            &paths?,
435            dtype,
436            device,
437            silent,
438            mapper,
439            in_situ_quant,
440            paged_attn_config,
441        )
442    }
443
444    fn get_id(&self) -> String {
445        self.xlora_model_id
446            .as_deref()
447            .unwrap_or(&self.model_id)
448            .to_string()
449    }
450
451    fn get_kind(&self) -> ModelKind {
452        self.kind.clone()
453    }
454}
455
456impl PreProcessingMixin for GGMLPipeline {
457    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
458        Some(self.chat_template.clone())
459    }
460    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
461        None
462    }
463}
464
465impl IsqPipelineMixin for GGMLPipeline {
466    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
467        anyhow::bail!(
468            "You are trying to in-situ requantize a GGML model. This will not do anything."
469        )
470    }
471}
472
473impl CacheManagerMixin for GGMLPipeline {
474    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
475        FullCacheManager.clone_in_cache(self, seqs, false)
476    }
477    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
478        FullCacheManager.clone_out_cache(self, seqs, false)
479    }
480    fn set_none_cache(
481        &self,
482        seqs: &mut [&mut Sequence],
483        reset_non_granular: bool,
484        modify_draft_cache: bool,
485
486        load_preallocated_cache: bool,
487    ) {
488        FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
489        if reset_non_granular {
490            self.reset_non_granular_state()
491        }
492    }
493    fn cache(&self) -> &EitherCache {
494        match self.model {
495            Model::Llama(ref model) => &model.cache,
496            Model::XLoraLlama(ref model) => &model.cache,
497        }
498    }
499}
500
501impl MetadataMixin for GGMLPipeline {
502    fn device(&self) -> Device {
503        match self.model {
504            Model::Llama(ref model) => model.device.clone(),
505            Model::XLoraLlama(ref model) => model.device.clone(),
506        }
507    }
508    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
509        Some(self.tokenizer.clone())
510    }
511    fn name(&self) -> String {
512        self.model_id.clone()
513    }
514    fn reset_non_granular_state(&self) {
515        if let Some(s) = self.non_granular_state.as_ref() {
516            *self.cache().full().get_scalings_cache() = None;
517            *get_mut_arcmutex!(s.non_granular_index) = 0;
518        }
519    }
520    fn get_metadata(&self) -> Arc<GeneralMetadata> {
521        self.metadata.clone()
522    }
523    fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
524        self.generation_defaults.clone()
525    }
526    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
527        None
528    }
529}
530
531#[async_trait::async_trait]
532impl Pipeline for GGMLPipeline {
533    fn forward_inputs(
534        &mut self,
535        inputs: Box<dyn Any>,
536        return_raw_logits: bool,
537    ) -> Result<ForwardInputsResult, hanzo_ml::Error> {
538        let ModelInputs {
539            input_ids,
540            input_ids_full,
541            seqlen_offsets,
542            seqlen_offsets_full,
543            context_lens,
544            position_ids: _,    // NOTE(hanzoai): ignore, it is for phi3
545            paged_attn_meta: _, // NOTE(hanzoai): ignore it for ggml
546            flash_meta,         // NOTE(hanzoai): ignore it for ggml dequant into f32
547            flash_meta_full,    // NOTE(hanzoai): ignore it for ggml dequant into f32
548        } = *inputs.downcast().expect("Downcast failed.");
549        let logits = match self.model {
550            Model::Llama(ref model) => {
551                model.forward(&input_ids, &seqlen_offsets, context_lens, None)?
552            }
553            Model::XLoraLlama(ref model) => model.forward(
554                &input_ids,
555                input_ids_full.as_ref().unwrap_or(&input_ids),
556                &seqlen_offsets,
557                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
558                self.no_kv_cache,
559                &self.non_granular_state,
560                context_lens,
561                &flash_meta,
562                flash_meta_full.as_ref().unwrap_or(&flash_meta),
563            )?,
564        };
565        if return_raw_logits {
566            Ok(ForwardInputsResult::RawLogits { logits })
567        } else {
568            Ok(ForwardInputsResult::CausalGeneration { logits })
569        }
570    }
571    async fn sample_causal_gen(
572        &self,
573        seqs: &mut [&mut Sequence],
574        logits: Vec<Tensor>,
575        prefix_cacher: &mut PrefixCacheManagerV2,
576        disable_eos_stop: bool,
577        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
578    ) -> Result<(), hanzo_ml::Error> {
579        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
580    }
581    fn category(&self) -> ModelCategory {
582        ModelCategory::Text
583    }
584}
585
586// TODO
587impl AnyMoePipelineMixin for GGMLPipeline {}