
mistralrs_core/pipeline/embedding.rs

use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
    ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
use crate::embedding_normal_model_loader;
use crate::embedding_normal_model_loader_sharded;
use crate::get_embedding_paths;
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::EmbeddingLoaderType;
use crate::pipeline::EmbeddingModel;
use crate::pipeline::EmbeddingModelLoader;
use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::Modalities;
use crate::SupportedModality;
use crate::{
    api_get_file, get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Context;
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::{Linear, Module};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::safetensors::MmapedSafetensors;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::borrow::Cow;
use std::env;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub struct EmbeddingPipeline {
    model: Box<dyn EmbeddingModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    config: String,
    modules_ser: String,
    modules_manifest: Vec<EmbeddingModulePaths>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    modules: Vec<Box<dyn Module + Send + Sync>>,
    processor: Arc<dyn Processor + Send + Sync>,
}

/// A loader for an embedding (non-quantized) model.
pub struct EmbeddingLoader {
    inner: Box<dyn EmbeddingModelLoader>,
    model_id: String,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for an embedding (non-quantized) model.
pub struct EmbeddingLoaderBuilder {
    model_id: Option<String>,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading an embedding model.
pub struct EmbeddingSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub hf_cache_path: Option<PathBuf>,
}
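// Construction sketch (values are illustrative, not prescriptive):
//
//     let cfg = EmbeddingSpecificConfig {
//         hf_cache_path: Some(PathBuf::from("/path/to/hf-cache")),
//         ..Default::default()
//     };
//
// Leaving `write_uqff`/`from_uqff` as `None` skips UQFF serialization and loading.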

impl EmbeddingLoaderBuilder {
    pub fn new(
        config: EmbeddingSpecificConfig,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
    ) -> Self {
        Self {
            config,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn EmbeddingModelLoader> = match loader {
            Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
            Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
            None => Box::new(AutoEmbeddingLoader),
        };
        Box::new(EmbeddingLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            tokenizer_json: self.tokenizer_json,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}
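// Usage sketch (the model id and loader type below are illustrative; passing
// `None` for the loader type falls back to `AutoEmbeddingLoader`):
//
//     let loader = EmbeddingLoaderBuilder::new(
//         EmbeddingSpecificConfig::default(),
//         None,
//         Some("Qwen/Qwen3-Embedding-0.6B".to_string()),
//     )
//     .build(Some(EmbeddingLoaderType::Qwen3Embedding));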

impl Loader for EmbeddingLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
            EmbeddingModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for embedding models, disabling it.");
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };
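                    // The averaged pack factor is a rough size divisor for the UQFF
                    // artifacts: e.g. a ~4-bit scheme packs roughly four weights per
                    // dtype-width element, so the per-layer size estimates used for
                    // auto device mapping shrink by about that factor.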

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
            &available_devices,
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
            &available_devices,
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = !device.is_cuda() && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }
        // Logic for ISQ here: if no calibration (i.e., imatrix) is needed, allow immediate ISQ. Otherwise, fall back to the normal path.
        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        loading_isq |= topology_requires_post_quant;

        // Load onto the regular device if not using isq
        let load_device = if !loading_isq {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };
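        // PagedAttention was disabled above for embedding models, so this always
        // resolves to eager attention.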

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

        let modules_config: Vec<_> = paths
            .get_modules()
            .context("Embedding models require the `modules.json` file.")?
            .to_vec();
        assert!(matches!(
            modules_config.first(),
            Some(EmbeddingModulePaths::Transformer { .. })
        ));
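        // The manifest mirrors the Sentence-Transformers `modules.json` layout; a
        // typical (illustrative) ordering is Transformer -> Pooling -> Dense ->
        // Normalize, with the head modules applied to the transformer output in
        // manifest order below.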

        let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
        for module in &modules_config {
            match module {
                EmbeddingModulePaths::Transformer { .. } => (),
                EmbeddingModulePaths::Pooling { config, .. } => {
                    let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    modules.push(Box::new(layer));
                }
                EmbeddingModulePaths::Dense { config, model, .. } => {
                    let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    let safetensors = unsafe { MmapedSafetensors::new(model)? };
                    let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
                    let bias = if config.bias {
                        Some(safetensors.load("linear.bias", &device, Some(dtype))?)
                    } else {
                        None
                    };
                    let (out_f, in_f) = weight.dims2()?;
                    assert_eq!((out_f, in_f), (config.out_features, config.in_features));
                    if !matches!(config.activation_function, DenseActivation::Identity) {
                        anyhow::bail!("Expected Identity activation function.");
                    }

                    modules.push(Box::new(Linear::new(weight, bias)));
                }
                EmbeddingModulePaths::Normalize { .. } => {
                    modules.push(Box::new(Normalize));
                }
            }
        }
        let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be more optimally loaded.
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;

        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                None,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: Some(&modules_ser),
                    module_paths: Some(&modules_config),
                },
                Arc::new(new_multi_progress()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let has_causal_attention = self.inner.has_causal_attention(&config)?;
        let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
        Ok(Arc::new(Mutex::new(EmbeddingPipeline {
            model,
            tokenizer: tokenizer.into(),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so it's OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Embedding],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            config,
            modules_ser,
            modules_manifest: modules_config,
            mapper: pipeline_mapper,
            modules,
            processor: Arc::new(EmbeddingProcessor {
                has_causal_attention,
            }),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for EmbeddingPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        self.processor.clone()
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for EmbeddingPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                None,
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &None,
                    generation_config: None,
                    config: self.config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: Some(&self.modules_ser),
                    module_paths: Some(&self.modules_manifest),
                },
                Arc::new(new_multi_progress()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for EmbeddingPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
}

impl MetadataMixin for EmbeddingPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for EmbeddingPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
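        // Run the base transformer, then apply the head modules built from the
        // manifest (pooling, dense projection, normalization) in order.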

        let mut xs = self.model.forward(&input_ids, &flash_meta)?;
        for module in &self.modules {
            xs = module.forward(&xs)?;
        }

        Ok(ForwardInputsResult::Embeddings { embeddings: xs })
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Embedding
    }
}

impl AnyMoePipelineMixin for EmbeddingPipeline {}