use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
    ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
use crate::embedding_normal_model_loader;
use crate::embedding_normal_model_loader_sharded;
use crate::get_embedding_paths;
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::EmbeddingLoaderType;
use crate::pipeline::EmbeddingModel;
use crate::pipeline::EmbeddingModelLoader;
use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::Modalities;
use crate::SupportedModality;
use crate::{
    api_get_file, get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Context;
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::{Linear, Module};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::safetensors::MmapedSafetensors;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::borrow::Cow;
use std::env;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};
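
/// Pipeline for embedding models: wraps the loaded transformer, its tokenizer, and the
/// post-processing modules (pooling, dense projection, normalization) applied to its output.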
pub struct EmbeddingPipeline {
    model: Box<dyn EmbeddingModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    config: String,
    modules_ser: String,
    modules_manifest: Vec<EmbeddingModulePaths>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    modules: Vec<Box<dyn Module + Send + Sync>>,
    processor: Arc<dyn Processor + Send + Sync>,
}
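
/// [`Loader`] implementation for embedding models, parameterized over a concrete
/// [`EmbeddingModelLoader`] (e.g. EmbeddingGemma or Qwen3-Embedding).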
pub struct EmbeddingLoader {
    inner: Box<dyn EmbeddingModelLoader>,
    model_id: String,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}
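
/// Builder for an [`EmbeddingLoader`], created via [`EmbeddingLoaderBuilder::new`].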
#[derive(Default)]
pub struct EmbeddingLoaderBuilder {
    model_id: Option<String>,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}
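
/// Configuration specific to embedding models: device topology plus optional UQFF
/// read/write paths and a custom Hugging Face cache location.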
#[derive(Clone, Default)]
pub struct EmbeddingSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub hf_cache_path: Option<PathBuf>,
}

impl EmbeddingLoaderBuilder {
    pub fn new(
        config: EmbeddingSpecificConfig,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
    ) -> Self {
        Self {
            config,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }
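
    /// Build the [`Loader`]. If no explicit [`EmbeddingLoaderType`] is given, the
    /// architecture is auto-detected via [`AutoEmbeddingLoader`].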
    pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn EmbeddingModelLoader> = match loader {
            Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
            Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
            None => Box::new(AutoEmbeddingLoader),
        };
        Box::new(EmbeddingLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            tokenizer_json: self.tokenizer_json,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for EmbeddingLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
            EmbeddingModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for embedding models, disabling it.");
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();
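
        // Determine the candidate devices: daemon workers bind to a CUDA device derived from
        // their worker rank, NCCL uses device 0, and otherwise all devices similar to the
        // requested one are considered.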
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };
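
        // With NCCL, device mapping is delegated to the distributed layer. For automatic
        // mapping, estimate per-layer and non-mapped sizes (accounting for the weight pack
        // factor of any UQFF/ISQ quantization) to decide which layers go on which device.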
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }
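
        // Two mappers are derived from the same setting: `pipeline_mapper` is kept by the
        // pipeline, while `mapper` drives loading and dtype selection below.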
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
            &available_devices,
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
            &available_devices,
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }
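
        // Collect per-layer ISQ/device overrides from the topology and decide whether ISQ can
        // be applied immediately during loading or must run as a post-quantization pass.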
        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = !device.is_cuda() && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        loading_isq |= topology_requires_post_quant;

        let load_device = if !loading_isq {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());
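
        // Build the non-transformer modules declared in `modules.json` (pooling, dense
        // projection, normalization); the first entry must be the transformer itself.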
        let modules_config: Vec<_> = paths
            .get_modules()
            .context("Embedding models require the `modules.json` file.")?
            .to_vec();
        assert!(matches!(
            modules_config.first(),
            Some(EmbeddingModulePaths::Transformer { .. })
        ));

        let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
        for module in &modules_config {
            match module {
                EmbeddingModulePaths::Transformer { .. } => (),
                EmbeddingModulePaths::Pooling { config, .. } => {
                    let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    modules.push(Box::new(layer));
                }
                EmbeddingModulePaths::Dense { config, model, .. } => {
                    let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    let safetensors = unsafe { MmapedSafetensors::new(model)? };
                    let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
                    let bias = if config.bias {
                        Some(safetensors.load("linear.bias", &device, Some(dtype))?)
                    } else {
                        None
                    };
                    let (out_f, in_f) = weight.dims2()?;
                    assert_eq!((out_f, in_f), (config.out_features, config.in_features));
                    if !matches!(config.activation_function, DenseActivation::Identity) {
                        anyhow::bail!("Expected Identity activation function.");
                    }

                    modules.push(Box::new(Linear::new(weight, bias)));
                }
                EmbeddingModulePaths::Normalize { .. } => {
                    modules.push(Box::new(Normalize));
                }
            }
        }
        let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);
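
        // Load the transformer weights, sharded across ranks when NCCL is in use, otherwise
        // through the standard embedding model loader.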
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;

        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;
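
        // Run the ISQ/UQFF-serialization pass if requested; otherwise, if loading from UQFF,
        // restore the pre-quantized weights from the serialized artifacts.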
        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                None,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: Some(&modules_ser),
                    module_paths: Some(&modules_config),
                },
                Arc::new(new_multi_progress()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }
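
        // Assemble the pipeline. Embedding models run without a KV cache and report the
        // `Embedding` output modality.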
        let has_causal_attention = self.inner.has_causal_attention(&config)?;
        let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
        Ok(Arc::new(Mutex::new(EmbeddingPipeline {
            model,
            tokenizer: tokenizer.into(),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Embedding],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            config,
            modules_ser,
            modules_manifest: modules_config,
            mapper: pipeline_mapper,
            modules,
            processor: Arc::new(EmbeddingProcessor {
                has_causal_attention,
            }),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for EmbeddingPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        self.processor.clone()
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for EmbeddingPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                None,
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &None,
                    generation_config: None,
                    config: self.config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: Some(&self.modules_ser),
                    module_paths: Some(&self.modules_manifest),
                },
                Arc::new(new_multi_progress()),
            )
            .map_err(anyhow::Error::msg)
    }
}
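
// Embedding models keep no KV cache, so the cache-management hooks are no-ops.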
impl CacheManagerMixin for EmbeddingPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
}

impl MetadataMixin for EmbeddingPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}
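
// The forward pass runs the transformer and then applies the pooling/dense/normalize
// modules in sequence to produce the final embeddings.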
#[async_trait::async_trait]
impl Pipeline for EmbeddingPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");

        let mut xs = self.model.forward(&input_ids, &flash_meta)?;
        for module in &self.modules {
            xs = module.forward(&xs)?;
        }

        Ok(ForwardInputsResult::Embeddings { embeddings: xs })
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Embedding
    }
}

impl AnyMoePipelineMixin for EmbeddingPipeline {}