Skip to main content

hanzo_engine/pipeline/
embedding.rs

1use super::isq::{UqffFullSer, WeightLoadingMode, WeightLoadingState};
2use super::{
3    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
4    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
5    ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
6};
7use crate::attention::ATTENTION_CHUNK_SIZE;
8use crate::device_map::{self, DeviceMapper};
9use crate::distributed::{self, use_ring, WorkerTransferData};
10use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
11use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
12use crate::embedding_normal_model_loader;
13use crate::embedding_normal_model_loader_sharded;
14use crate::get_embedding_paths;
15use crate::paged_attention::AttentionImplementation;
16use crate::pipeline::loaders::auto_device_map;
17use crate::pipeline::loaders::QuantizationConfigShim;
18use crate::pipeline::sampling::sample_and_add_toks;
19use crate::pipeline::EmbeddingLoaderType;
20use crate::pipeline::EmbeddingModel;
21use crate::pipeline::EmbeddingModelLoader;
22use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
23use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
24use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
25use crate::prefix_cacher::PrefixCacheManagerV2;
26use crate::sequence::Sequence;
27use crate::utils::tokenizer::get_tokenizer;
28use crate::utils::{
29    progress::{new_multi_progress, ProgressScopeGuard},
30    tokens::get_token,
31    varbuilder_utils::from_mmaped_safetensors,
32};
33use crate::Modalities;
34use crate::SupportedModality;
35use crate::{
36    get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
37    GLOBAL_HF_CACHE,
38};
39use anyhow::Context;
40use anyhow::Result;
41use hanzo_ml::{Device, Tensor};
42use hanzo_nn::{Linear, Module};
43use hanzo_quant::log::once_log_info;
44use hanzo_quant::safetensors::MmapedSafetensors;
45use hanzo_quant::{
46    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
47};
48use hf_hub::Cache;
49use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
50use rand_isaac::Isaac64Rng;
51use std::any::Any;
52use std::borrow::Cow;
53use std::env;
54use std::path::{Path, PathBuf};
55use std::str::FromStr;
56use std::sync::{Arc, RwLock};
57use tokenizers::Tokenizer;
58use tokio::sync::Mutex;
59use tracing::{debug, info, trace, warn};
60
/// A loaded embedding model, ready to serve requests.
///
/// A forward pass runs the transformer backbone (`model`). It then feeds the result through
/// each entry of `modules` in order and returns the final tensor as the embeddings.
pub struct EmbeddingPipeline {
    /// Transformer backbone.
    model: Box<dyn EmbeddingModel + Send + Sync>,
    /// Tokenizer exposed through `MetadataMixin::tokenizer`.
    tokenizer: Arc<Tokenizer>,
    /// Identifier reported by `MetadataMixin::name`.
    model_id: String,
    /// Shared metadata returned by `MetadataMixin::get_metadata`.
    metadata: Arc<GeneralMetadata>,
    /// Topology reused when re-quantizing in `re_isq_model`.
    topology: Option<Topology>,
    /// Forwarded to `quantize` when re-quantizing.
    silent: bool,
    /// Raw text of the model's config file.
    config: String,
    /// Serialized form of `modules_manifest`, passed to `UqffFullSer` when (re)quantizing.
    modules_ser: String,
    /// Module entries parsed from `modules.json`.
    modules_manifest: Vec<EmbeddingModulePaths>,
    /// Device mapper exposed through `MetadataMixin::device_mapper`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Post-processing modules (pooling / dense / normalize) applied after the backbone, in order.
    modules: Vec<Box<dyn Module + Send + Sync>>,
    /// Input processor returned by `PreProcessingMixin::get_processor`.
    processor: Arc<dyn Processor + Send + Sync>,
}
75
/// A loader for an embedding model.
///
/// Weights may still be quantized at load time (ISQ) or loaded from UQFF artifacts.
/// Build one through `EmbeddingLoaderBuilder`.
pub struct EmbeddingLoader {
    /// Architecture-specific loader selected in `EmbeddingLoaderBuilder::build`.
    inner: Box<dyn EmbeddingModelLoader>,
    model_id: String,
    config: EmbeddingSpecificConfig,
    /// Only `ModelKind::Normal` is handled by the loading code in this file.
    kind: ModelKind,
    /// NOTE(review): not read directly in this file; presumably consumed by the path-resolution
    /// macros as an explicit tokenizer location — confirm.
    tokenizer_json: Option<String>,
    // The next three fields are written by `load_model_from_hf` (which only has `&self`),
    // hence the interior mutability.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    /// Resolved UQFF artifact paths; read by `load_model_from_path`.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Used to initialise `GLOBAL_HF_CACHE` if it is not already set.
    hf_cache_path: Option<PathBuf>,
    /// Set by `EmbeddingLoaderBuilder::with_lora`. NOTE(review): not referenced directly in this
    /// file; presumably consumed by the path-resolution macros — confirm.
    lora_adapter_ids: Option<Vec<String>>,
    /// Chooses the wording of the weight-loading log message.
    load_context: EmbeddingLoadContext,
}
90
/// The role an embedding model is being loaded for.
///
/// Within this file it only selects the noun used in the weight-loading log message
/// (see [`EmbeddingLoadContext::weight_target`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) enum EmbeddingLoadContext {
    /// The model served as the main pipeline (the default).
    #[default]
    Primary,
    /// Presumably a secondary embedding model used to back search — confirm with callers.
    Search,
}

impl EmbeddingLoadContext {
    /// Human-readable name of what is being loaded, used in log messages.
    fn weight_target(self) -> &'static str {
        match self {
            Self::Primary => "model",
            Self::Search => "search embedding model",
        }
    }
}
106
#[derive(Default)]
/// A builder for an embedding-model [`Loader`].
///
/// Weights may still be quantized at load time (ISQ) or loaded from UQFF artifacts.
pub struct EmbeddingLoaderBuilder {
    /// Required: `build` needs a model id.
    model_id: Option<String>,
    config: EmbeddingSpecificConfig,
    /// `ModelKind::Normal` unless `with_lora` switches it to a LoRA adapter kind.
    kind: ModelKind,
    tokenizer_json: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
    load_context: EmbeddingLoadContext,
}
118
#[derive(Clone, Default)]
/// Config specific to loading an embedding model.
pub struct EmbeddingSpecificConfig {
    /// Optional per-layer topology. It is used for device mapping, per-pattern ISQ overrides,
    /// and quantization.
    pub topology: Option<Topology>,
    /// When set, the loaded model is serialized to UQFF at this path.
    pub write_uqff: Option<PathBuf>,
    /// When set, pre-quantized UQFF artifacts are loaded instead of quantizing at load time.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// NOTE(review): not referenced by the loading code in this file (the loader uses the
    /// builder's own `hf_cache_path`); confirm whether other callers read it.
    pub hf_cache_path: Option<PathBuf>,
}
127
128impl EmbeddingLoaderBuilder {
129    pub fn new(
130        config: EmbeddingSpecificConfig,
131        tokenizer_json: Option<String>,
132        model_id: Option<String>,
133    ) -> Self {
134        Self {
135            config,
136            tokenizer_json,
137            model_id,
138            kind: ModelKind::Normal,
139            hf_cache_path: None,
140            ..Default::default()
141        }
142    }
143
144    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
145        self.hf_cache_path = Some(hf_cache_path);
146        self
147    }
148
149    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
150        self.kind = ModelKind::Adapter {
151            adapter: AdapterKind::Lora,
152        };
153        self.lora_adapter_ids = Some(lora_adapter_ids);
154        self
155    }
156
157    pub(crate) fn with_load_context(mut self, load_context: EmbeddingLoadContext) -> Self {
158        self.load_context = load_context;
159        self
160    }
161
162    pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
163        let loader: Box<dyn EmbeddingModelLoader> = match loader {
164            Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
165            Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
166            None => Box::new(AutoEmbeddingLoader),
167        };
168        Box::new(EmbeddingLoader {
169            inner: loader,
170            model_id: self.model_id.unwrap(),
171            config: self.config,
172            kind: self.kind,
173            tokenizer_json: self.tokenizer_json,
174            token_source: RwLock::new(None),
175            revision: RwLock::new(None),
176            from_uqff: RwLock::new(None),
177            hf_cache_path: self.hf_cache_path,
178            lora_adapter_ids: self.lora_adapter_ids,
179            load_context: self.load_context,
180        })
181    }
182}
183
184impl Loader for EmbeddingLoader {
185    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
186    fn load_model_from_hf(
187        &self,
188        revision: Option<String>,
189        token_source: TokenSource,
190        dtype: &dyn TryIntoDType,
191        device: &Device,
192        silent: bool,
193        mapper: DeviceMapSetting,
194        in_situ_quant: Option<IsqType>,
195        paged_attn_config: Option<PagedAttentionConfig>,
196    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
197        let _progress_guard = ProgressScopeGuard::new(silent);
198        let cache = self
199            .hf_cache_path
200            .clone()
201            .map(Cache::new)
202            .unwrap_or_default();
203        GLOBAL_HF_CACHE.get_or_init(|| cache);
204
205        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
206            EmbeddingModelPaths,
207            &token_source,
208            revision.clone(),
209            self,
210            None,
211            None,
212            silent,
213            self.config.from_uqff.is_some()
214        );
215        *self
216            .token_source
217            .write()
218            .expect("Failed to write to token source") = Some(token_source);
219        *self.revision.write().expect("Failed to write to revision") = revision.clone();
220        if let Some(from_uqff) = self.config.from_uqff.clone() {
221            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
222        }
223        self.load_model_from_path(
224            &paths?,
225            dtype,
226            device,
227            silent,
228            mapper,
229            in_situ_quant,
230            paged_attn_config,
231        )
232    }
233
234    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
235    fn load_model_from_path(
236        &self,
237        paths: &Box<dyn ModelPaths>,
238        dtype: &dyn TryIntoDType,
239        device: &Device,
240        silent: bool,
241        mut mapper: DeviceMapSetting,
242        in_situ_quant: Option<IsqType>,
243        mut paged_attn_config: Option<PagedAttentionConfig>,
244    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
245        let _progress_guard = ProgressScopeGuard::new(silent);
246        let config = std::fs::read_to_string(paths.get_config_filename())?;
247
248        if paged_attn_config.is_some() {
249            warn!("PagedAttention is not supported for embedding models, disabling it.");
250            paged_attn_config = None;
251        }
252
253        debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
254
255        let use_nccl = hanzo_quant::distributed::use_nccl();
256
257        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
258            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
259            let WorkerTransferData::Init { id: _, worker_rank } = payload;
260            vec![hanzo_ml::Device::new_cuda_with_stream(worker_rank + 1)?]
261        } else if use_nccl || use_ring() {
262            vec![hanzo_ml::Device::new_cuda_with_stream(0)?]
263        } else {
264            device_map::get_all_similar_devices(device)?
265        };
266        #[cfg(feature = "cuda")]
267        for device in &available_devices {
268            if let Device::Cuda(dev) = device {
269                unsafe { dev.disable_event_tracking() };
270            }
271        }
272        let device = if use_nccl || use_ring() {
273            available_devices[0].clone()
274        } else {
275            device.clone()
276        };
277
278        // If auto, convert to Map if not using nccl
279        if use_nccl || use_ring() {
280            mapper = DeviceMapSetting::DummyNccl {
281                nm_device: available_devices[0].clone(),
282            };
283        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
284            // Initial dtype
285            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
286
287            // ISQ or UQFF: quantized path
288            // Match logic below where UQFF has priority
289            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
290                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
291                    let weight_pack_factor = {
292                        let ser_artifacts =
293                            unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(serialized)? };
294                        let mut total_pack_factors = 0;
295                        let total_tensors = ser_artifacts.tensors().len();
296                        for (_, artifact) in ser_artifacts.tensors() {
297                            let artifact = artifact.data();
298                            // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
299                            let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
300                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
301                            {
302                                QuantizedSerdeType::Hqq => {
303                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
304                                        .pack_factor(dtype)
305                                }
306                                QuantizedSerdeType::Gguf => {
307                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
308                                        .pack_factor(dtype)
309                                }
310                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
311                                QuantizedSerdeType::Unquant => 1,
312                                QuantizedSerdeType::Afq => {
313                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
314                                        .pack_factor(dtype)
315                                }
316                                QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
317                                QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
318                            };
319                            total_pack_factors += pack_factor;
320                        }
321
322                        total_pack_factors / total_tensors
323                    };
324
325                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
326                        &config,
327                        dtype,
328                        weight_pack_factor,
329                        None,
330                    )?;
331                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
332                        &config,
333                        dtype,
334                        weight_pack_factor,
335                        None,
336                    )?;
337                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
338                    (
339                        layer_sizes_in_bytes,
340                        non_mapped_size_in_bytes,
341                        layer_sizes_sum + non_mapped_size_in_bytes,
342                    )
343                } else if let Some(isq) = in_situ_quant {
344                    let weight_pack_factor = isq.pack_factor(dtype);
345                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
346                        &config,
347                        dtype,
348                        weight_pack_factor,
349                        None,
350                    )?;
351                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
352                        &config,
353                        dtype,
354                        weight_pack_factor,
355                        None,
356                    )?;
357                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
358                    (
359                        layer_sizes_in_bytes,
360                        non_mapped_size_in_bytes,
361                        layer_sizes_sum + non_mapped_size_in_bytes,
362                    )
363                } else {
364                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
365                    let weight_pack_factor =
366                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
367                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
368                        &config,
369                        dtype,
370                        weight_pack_factor,
371                        None,
372                    )?;
373                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
374                        &config,
375                        dtype,
376                        weight_pack_factor,
377                        None,
378                    )?;
379                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
380                    (
381                        layer_sizes_in_bytes,
382                        non_mapped_size_in_bytes,
383                        layer_sizes_sum + non_mapped_size_in_bytes,
384                    )
385                };
386
387            let new = auto_device_map::get_device_layers(
388                &*self.inner,
389                &config,
390                self.inner.num_layers(&config)?,
391                layer_sizes_in_bytes,
392                non_mapped_size_in_bytes,
393                total_model_size_in_bytes,
394                &available_devices,
395                dtype,
396                &params,
397                paged_attn_config.as_ref(),
398            )?;
399            mapper = DeviceMapSetting::Map(new);
400        }
401
402        let pipeline_mapper = mapper.into_mapper(
403            self.inner.num_layers(&config)?,
404            &device,
405            self.config.topology.as_ref(),
406            &available_devices,
407        )?;
408        let mapper = mapper.into_mapper(
409            self.inner.num_layers(&config)?,
410            &device,
411            self.config.topology.as_ref(),
412            &available_devices,
413        )?;
414        let mut layer_devices = Vec::new();
415        for layer in 0..self.inner.num_layers(&config)? {
416            let device = mapper.device_for(layer, false).cloned();
417            layer_devices.push(device);
418        }
419        let dtype = mapper.get_min_dtype(dtype)?;
420
421        trace!("Model config: {:?}", self.inner.get_config_repr(&config)?);
422        if crate::using_flash_attn() {
423            once_log_info("FlashAttention is enabled.");
424        }
425
426        let topology_overrides = self
427            .config
428            .topology
429            .as_ref()
430            .map(|topology| {
431                topology
432                    .pattern_overrides()
433                    .into_iter()
434                    .map(|(regex, layer)| ImmediateIsqOverride {
435                        predicate: regex,
436                        ty: layer.isq,
437                        device: layer.device.clone(),
438                    })
439                    .collect::<Vec<_>>()
440            })
441            .unwrap_or_default();
442        let has_override_isq = topology_overrides
443            .iter()
444            .any(|override_entry| override_entry.ty.is_some());
445        let topology_requires_post_quant = self
446            .config
447            .topology
448            .as_ref()
449            .is_some_and(|topology| topology.requires_post_quantization());
450
451        let allow_immediate_cli = in_situ_quant.is_some();
452
453        let mut immediate_ty = None;
454        let mut immediate_predicates = Vec::new();
455        if allow_immediate_cli {
456            immediate_ty = in_situ_quant;
457            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
458            info!("Applying ISQ to {in_situ_quant:?}");
459            if immediate_predicates.is_empty() {
460                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
461            }
462        }
463
464        let use_immediate = allow_immediate_cli || has_override_isq;
465        if use_immediate {
466            let (pool, num_threads) = hanzo_quant::create_isq_thread_pool(immediate_ty);
467            info!("Applying immediate ISQ in parallel on {num_threads} threads.");
468            hanzo_quant::set_immediate_isq_with_pool(
469                immediate_ty,
470                immediate_predicates.clone(),
471                topology_overrides.clone(),
472                pool,
473            );
474        }
475
476        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
477        let mut loading_isq = if use_immediate {
478            false
479        } else {
480            in_situ_quant.is_some()
481        };
482        loading_isq |= topology_requires_post_quant;
483        loading_isq |= self.config.from_uqff.is_some();
484
485        // Load onto the regular device if not using isq.
486        // For immediate ISQ on discrete GPUs, load to CPU: the mapper will set the correct target
487        // device per-layer, and linear constructors will override to CPU for ISQ-targeted weights.
488        // On integrated/unified memory systems (e.g. Grace Blackwell), CPU and GPU share memory,
489        // so we load directly to the device.
490        let load_device = if !loading_isq {
491            loading_isq = false;
492            if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
493                Device::Cpu
494            } else {
495                device.clone()
496            }
497        } else {
498            Device::Cpu
499        };
500
501        let attention_mechanism = if paged_attn_config.is_some() {
502            AttentionImplementation::PagedAttention
503        } else {
504            AttentionImplementation::Eager
505        };
506
507        let multi_progress = Arc::new(new_multi_progress());
508
509        let modules_config: Vec<_> = paths
510            .get_modules()
511            .context("Embedding models require the `modules.json` file.")?
512            .to_vec();
513        assert!(matches!(
514            modules_config.first(),
515            Some(EmbeddingModulePaths::Transformer { .. })
516        ));
517
518        let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
519        for module in &modules_config {
520            match module {
521                EmbeddingModulePaths::Transformer { .. } => (),
522                EmbeddingModulePaths::Pooling { config, .. } => {
523                    let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
524                    modules.push(Box::new(layer));
525                }
526                EmbeddingModulePaths::Dense { config, model, .. } => {
527                    let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
528                    let safetensors = unsafe { MmapedSafetensors::new(model)? };
529                    let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
530                    let bias = if config.bias {
531                        Some(safetensors.load("linear.bias", &device, Some(dtype))?)
532                    } else {
533                        None
534                    };
535                    let (out_f, in_f) = weight.dims2()?;
536                    assert_eq!((out_f, in_f), (config.out_features, config.in_features));
537                    if !matches!(config.activation_function, DenseActivation::Identity) {
538                        anyhow::bail!("Expected Identity activation function.");
539                    }
540
541                    modules.push(Box::new(Linear::new(weight, bias)));
542                }
543                EmbeddingModulePaths::Normalize { .. } => {
544                    modules.push(Box::new(Normalize));
545                }
546            }
547        }
548        let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);
549
550        info!(
551            "{}",
552            WeightLoadingMode::from(WeightLoadingState {
553                from_uqff: self.config.from_uqff.is_some(),
554                loading_isq,
555                immediate_isq: use_immediate,
556                write_uqff: self.config.write_uqff.is_some(),
557            })
558            .message(self.load_context.weight_target())
559        );
560
561        let mut model = if use_nccl || use_ring() {
562            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
563                dtype,
564                &device,
565                &available_devices,
566                silent,
567                &config,
568                loading_isq,
569                self.config.from_uqff.is_some(),
570                IsqOrganization::Default,
571                &*self.inner,
572                paths.as_ref(),
573            )?;
574
575            // Special case for where things can be more optimially loaded.
576            match self.kind {
577                ModelKind::Normal => embedding_normal_model_loader_sharded!(
578                    sharded_vb,
579                    config,
580                    self.inner,
581                    mapper,
582                    loading_isq,
583                    device.clone(),
584                    attention_mechanism,
585                    multi_progress.clone(),
586                ),
587                _ => unreachable!(),
588            }
589        } else {
590            match self.kind {
591                ModelKind::Normal => embedding_normal_model_loader!(
592                    paths,
593                    Some(dtype),
594                    &load_device,
595                    layer_devices.clone(),
596                    config,
597                    self.inner,
598                    silent,
599                    mapper,
600                    loading_isq,
601                    self.config.from_uqff.is_some(),
602                    device.clone(),
603                    attention_mechanism,
604                    multi_progress,
605                ),
606                _ => unreachable!(),
607            }
608        };
609
610        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
611
612        let should_serialize = self.config.write_uqff.is_some();
613        let should_quantize_pass = loading_isq;
614
615        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
616            if should_quantize_pass {
617                debug!("Applying ISQ to all ranks.");
618            } else {
619                debug!("Serializing existing ISQ tensors without additional quantization.");
620            }
621            model.quantize(
622                in_situ_quant,
623                device.clone(),
624                self.config.topology.as_ref(),
625                silent,
626                None,
627                IsqOrganization::Default,
628                should_quantize_pass,
629                self.config.write_uqff.as_ref(),
630                UqffFullSer {
631                    tokenizer: &tokenizer,
632                    template_filename: paths.get_template_filename(),
633                    generation_config: paths.get_gen_conf_filename(),
634                    config: config.clone(),
635                    processor_filename: paths.get_processor_config(),
636                    preprocessor_filename: paths.get_preprocessor_config(),
637                    modules: Some(&modules_ser),
638                    module_paths: Some(&modules_config),
639                },
640                Arc::new(new_multi_progress()),
641            )?;
642        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
643            model.load_from_artifacts(
644                device.clone(),
645                self.config.topology.as_ref(),
646                silent,
647                from_uqff,
648            )?;
649        }
650
651        let has_causal_attention = self.inner.has_causal_attention(&config)?;
652        let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
653        Ok(Arc::new(Mutex::new(EmbeddingPipeline {
654            model,
655            tokenizer: tokenizer.into(),
656            model_id: self.model_id.clone(),
657            metadata: Arc::new(GeneralMetadata {
658                max_seq_len,
659                llg_factory: None,
660                is_xlora: false,
661                no_prefix_cache: false,
662                num_hidden_layers: 1, // FIXME(hanzoai): we know this is only for caching, so its OK.
663                eos_tok: vec![],
664                kind: ModelKind::Normal,
665                no_kv_cache: true, // NOTE(hanzoai): no cache for these.
666                activation_dtype: dtype,
667                sliding_window: None,
668                cache_config: None,
669                cache_engine: None,
670                model_metadata: None,
671                modalities: Modalities {
672                    input: vec![SupportedModality::Text],
673                    output: vec![SupportedModality::Embedding],
674                },
675            }),
676            topology: self.config.topology.clone(),
677            silent,
678            config,
679            modules_ser,
680            modules_manifest: modules_config,
681            mapper: pipeline_mapper,
682            modules,
683            processor: Arc::new(EmbeddingProcessor {
684                has_causal_attention,
685            }),
686        })))
687    }
688
689    fn get_id(&self) -> String {
690        self.model_id.to_string()
691    }
692
693    fn get_kind(&self) -> ModelKind {
694        self.kind.clone()
695    }
696}
697
698impl PreProcessingMixin for EmbeddingPipeline {
699    fn get_processor(&self) -> Arc<dyn Processor> {
700        self.processor.clone()
701    }
702    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
703        None
704    }
705    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
706        None
707    }
708}
709
710impl IsqPipelineMixin for EmbeddingPipeline {
711    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
712        let device = self.device().clone();
713        self.model
714            .quantize(
715                Some(dtype),
716                device,
717                self.topology.as_ref(),
718                self.silent,
719                None,
720                IsqOrganization::Default,
721                true,
722                None,
723                UqffFullSer {
724                    tokenizer: &self.tokenizer,
725                    template_filename: &None,
726                    generation_config: None,
727                    config: self.config.clone(),
728                    processor_filename: &None,
729                    preprocessor_filename: &None,
730                    modules: Some(&self.modules_ser),
731                    module_paths: Some(&self.modules_manifest),
732                },
733                Arc::new(new_multi_progress()),
734            )
735            .map_err(anyhow::Error::msg)
736    }
737}
738
739impl CacheManagerMixin for EmbeddingPipeline {
740    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
741    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
742    fn set_none_cache(
743        &self,
744        _seqs: &mut [&mut Sequence],
745        _reset_non_granular: bool,
746        _modify_draft_cache: bool,
747        _load_preallocated_cache: bool,
748    ) {
749    }
750    fn cache(&self) -> &EitherCache {
751        unreachable!()
752    }
753}
754
755impl MetadataMixin for EmbeddingPipeline {
756    fn device(&self) -> Device {
757        self.model.device().clone()
758    }
759    fn get_metadata(&self) -> Arc<GeneralMetadata> {
760        self.metadata.clone()
761    }
762    fn name(&self) -> String {
763        self.model_id.clone()
764    }
765    fn reset_non_granular_state(&self) {}
766    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
767        Some(self.tokenizer.clone())
768    }
769    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
770        Some(&*self.mapper)
771    }
772}
773
774#[async_trait::async_trait]
775impl Pipeline for EmbeddingPipeline {
776    fn forward_inputs(
777        &mut self,
778        inputs: Box<dyn Any>,
779        _return_raw_logits: bool,
780    ) -> hanzo_ml::Result<ForwardInputsResult> {
781        let ModelInputs {
782            input_ids,
783            flash_meta,
784        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
785
786        let mut xs = self.model.forward(&input_ids, &flash_meta)?;
787        for module in &self.modules {
788            xs = module.forward(&xs)?;
789        }
790
791        Ok(ForwardInputsResult::Embeddings { embeddings: xs })
792    }
793    async fn sample_causal_gen(
794        &self,
795        seqs: &mut [&mut Sequence],
796        logits: Vec<Tensor>,
797        prefix_cacher: &mut PrefixCacheManagerV2,
798        disable_eos_stop: bool,
799        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
800    ) -> Result<(), hanzo_ml::Error> {
801        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
802    }
803    fn category(&self) -> ModelCategory {
804        ModelCategory::Embedding
805    }
806}
807
/// Embedding pipelines use the trait's default AnyMoE method implementations.
impl AnyMoePipelineMixin for EmbeddingPipeline {}