1use super::isq::{UqffFullSer, WeightLoadingMode, WeightLoadingState};
2use super::{
3 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
4 EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
5 ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
6};
7use crate::attention::ATTENTION_CHUNK_SIZE;
8use crate::device_map::{self, DeviceMapper};
9use crate::distributed::{self, use_ring, WorkerTransferData};
10use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
11use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
12use crate::embedding_normal_model_loader;
13use crate::embedding_normal_model_loader_sharded;
14use crate::get_embedding_paths;
15use crate::paged_attention::AttentionImplementation;
16use crate::pipeline::loaders::auto_device_map;
17use crate::pipeline::loaders::QuantizationConfigShim;
18use crate::pipeline::sampling::sample_and_add_toks;
19use crate::pipeline::EmbeddingLoaderType;
20use crate::pipeline::EmbeddingModel;
21use crate::pipeline::EmbeddingModelLoader;
22use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
23use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
24use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
25use crate::prefix_cacher::PrefixCacheManagerV2;
26use crate::sequence::Sequence;
27use crate::utils::tokenizer::get_tokenizer;
28use crate::utils::{
29 progress::{new_multi_progress, ProgressScopeGuard},
30 tokens::get_token,
31 varbuilder_utils::from_mmaped_safetensors,
32};
33use crate::Modalities;
34use crate::SupportedModality;
35use crate::{
36 get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
37 GLOBAL_HF_CACHE,
38};
39use anyhow::Context;
40use anyhow::Result;
41use hanzo_ml::{Device, Tensor};
42use hanzo_nn::{Linear, Module};
43use hanzo_quant::log::once_log_info;
44use hanzo_quant::safetensors::MmapedSafetensors;
45use hanzo_quant::{
46 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
47};
48use hf_hub::Cache;
49use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
50use rand_isaac::Isaac64Rng;
51use std::any::Any;
52use std::borrow::Cow;
53use std::env;
54use std::path::{Path, PathBuf};
55use std::str::FromStr;
56use std::sync::{Arc, RwLock};
57use tokenizers::Tokenizer;
58use tokio::sync::Mutex;
59use tracing::{debug, info, trace, warn};
60
61pub struct EmbeddingPipeline {
62 model: Box<dyn EmbeddingModel + Send + Sync>,
63 tokenizer: Arc<Tokenizer>,
64 model_id: String,
65 metadata: Arc<GeneralMetadata>,
66 topology: Option<Topology>,
67 silent: bool,
68 config: String,
69 modules_ser: String,
70 modules_manifest: Vec<EmbeddingModulePaths>,
71 mapper: Box<dyn DeviceMapper + Send + Sync>,
72 modules: Vec<Box<dyn Module + Send + Sync>>,
73 processor: Arc<dyn Processor + Send + Sync>,
74}
75
76pub struct EmbeddingLoader {
78 inner: Box<dyn EmbeddingModelLoader>,
79 model_id: String,
80 config: EmbeddingSpecificConfig,
81 kind: ModelKind,
82 tokenizer_json: Option<String>,
83 token_source: RwLock<Option<TokenSource>>,
84 revision: RwLock<Option<String>>,
85 from_uqff: RwLock<Option<Vec<PathBuf>>>,
86 hf_cache_path: Option<PathBuf>,
87 lora_adapter_ids: Option<Vec<String>>,
88 load_context: EmbeddingLoadContext,
89}
90
/// Why an embedding model is being loaded; only affects how loading is reported in logs.
#[derive(Clone, Copy, Default)]
pub(crate) enum EmbeddingLoadContext {
    /// The embedding model is the primary model being served.
    #[default]
    Primary,
    /// The embedding model is loaded to support search.
    Search,
}

impl EmbeddingLoadContext {
    /// Noun phrase naming what is being loaded, used in the weight-loading log message.
    fn weight_target(self) -> &'static str {
        match self {
            Self::Search => "search embedding model",
            Self::Primary => "model",
        }
    }
}
106
107#[derive(Default)]
108pub struct EmbeddingLoaderBuilder {
110 model_id: Option<String>,
111 config: EmbeddingSpecificConfig,
112 kind: ModelKind,
113 tokenizer_json: Option<String>,
114 hf_cache_path: Option<PathBuf>,
115 lora_adapter_ids: Option<Vec<String>>,
116 load_context: EmbeddingLoadContext,
117}
118
119#[derive(Clone, Default)]
120pub struct EmbeddingSpecificConfig {
122 pub topology: Option<Topology>,
123 pub write_uqff: Option<PathBuf>,
124 pub from_uqff: Option<Vec<PathBuf>>,
125 pub hf_cache_path: Option<PathBuf>,
126}
127
128impl EmbeddingLoaderBuilder {
129 pub fn new(
130 config: EmbeddingSpecificConfig,
131 tokenizer_json: Option<String>,
132 model_id: Option<String>,
133 ) -> Self {
134 Self {
135 config,
136 tokenizer_json,
137 model_id,
138 kind: ModelKind::Normal,
139 hf_cache_path: None,
140 ..Default::default()
141 }
142 }
143
144 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
145 self.hf_cache_path = Some(hf_cache_path);
146 self
147 }
148
149 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
150 self.kind = ModelKind::Adapter {
151 adapter: AdapterKind::Lora,
152 };
153 self.lora_adapter_ids = Some(lora_adapter_ids);
154 self
155 }
156
157 pub(crate) fn with_load_context(mut self, load_context: EmbeddingLoadContext) -> Self {
158 self.load_context = load_context;
159 self
160 }
161
162 pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
163 let loader: Box<dyn EmbeddingModelLoader> = match loader {
164 Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
165 Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
166 None => Box::new(AutoEmbeddingLoader),
167 };
168 Box::new(EmbeddingLoader {
169 inner: loader,
170 model_id: self.model_id.unwrap(),
171 config: self.config,
172 kind: self.kind,
173 tokenizer_json: self.tokenizer_json,
174 token_source: RwLock::new(None),
175 revision: RwLock::new(None),
176 from_uqff: RwLock::new(None),
177 hf_cache_path: self.hf_cache_path,
178 lora_adapter_ids: self.lora_adapter_ids,
179 load_context: self.load_context,
180 })
181 }
182}
183
184impl Loader for EmbeddingLoader {
185 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
186 fn load_model_from_hf(
187 &self,
188 revision: Option<String>,
189 token_source: TokenSource,
190 dtype: &dyn TryIntoDType,
191 device: &Device,
192 silent: bool,
193 mapper: DeviceMapSetting,
194 in_situ_quant: Option<IsqType>,
195 paged_attn_config: Option<PagedAttentionConfig>,
196 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
197 let _progress_guard = ProgressScopeGuard::new(silent);
198 let cache = self
199 .hf_cache_path
200 .clone()
201 .map(Cache::new)
202 .unwrap_or_default();
203 GLOBAL_HF_CACHE.get_or_init(|| cache);
204
205 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
206 EmbeddingModelPaths,
207 &token_source,
208 revision.clone(),
209 self,
210 None,
211 None,
212 silent,
213 self.config.from_uqff.is_some()
214 );
215 *self
216 .token_source
217 .write()
218 .expect("Failed to write to token source") = Some(token_source);
219 *self.revision.write().expect("Failed to write to revision") = revision.clone();
220 if let Some(from_uqff) = self.config.from_uqff.clone() {
221 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
222 }
223 self.load_model_from_path(
224 &paths?,
225 dtype,
226 device,
227 silent,
228 mapper,
229 in_situ_quant,
230 paged_attn_config,
231 )
232 }
233
234 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
235 fn load_model_from_path(
236 &self,
237 paths: &Box<dyn ModelPaths>,
238 dtype: &dyn TryIntoDType,
239 device: &Device,
240 silent: bool,
241 mut mapper: DeviceMapSetting,
242 in_situ_quant: Option<IsqType>,
243 mut paged_attn_config: Option<PagedAttentionConfig>,
244 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
245 let _progress_guard = ProgressScopeGuard::new(silent);
246 let config = std::fs::read_to_string(paths.get_config_filename())?;
247
248 if paged_attn_config.is_some() {
249 warn!("PagedAttention is not supported for embedding models, disabling it.");
250 paged_attn_config = None;
251 }
252
253 debug!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
254
255 let use_nccl = hanzo_quant::distributed::use_nccl();
256
257 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
258 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
259 let WorkerTransferData::Init { id: _, worker_rank } = payload;
260 vec![hanzo_ml::Device::new_cuda_with_stream(worker_rank + 1)?]
261 } else if use_nccl || use_ring() {
262 vec![hanzo_ml::Device::new_cuda_with_stream(0)?]
263 } else {
264 device_map::get_all_similar_devices(device)?
265 };
266 #[cfg(feature = "cuda")]
267 for device in &available_devices {
268 if let Device::Cuda(dev) = device {
269 unsafe { dev.disable_event_tracking() };
270 }
271 }
272 let device = if use_nccl || use_ring() {
273 available_devices[0].clone()
274 } else {
275 device.clone()
276 };
277
278 if use_nccl || use_ring() {
280 mapper = DeviceMapSetting::DummyNccl {
281 nm_device: available_devices[0].clone(),
282 };
283 } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
284 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
286
287 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
290 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
291 let weight_pack_factor = {
292 let ser_artifacts =
293 unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(serialized)? };
294 let mut total_pack_factors = 0;
295 let total_tensors = ser_artifacts.tensors().len();
296 for (_, artifact) in ser_artifacts.tensors() {
297 let artifact = artifact.data();
298 let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
300 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
301 {
302 QuantizedSerdeType::Hqq => {
303 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
304 .pack_factor(dtype)
305 }
306 QuantizedSerdeType::Gguf => {
307 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
308 .pack_factor(dtype)
309 }
310 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
311 QuantizedSerdeType::Unquant => 1,
312 QuantizedSerdeType::Afq => {
313 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
314 .pack_factor(dtype)
315 }
316 QuantizedSerdeType::F8Q8 => IsqType::F8Q8.pack_factor(dtype),
317 QuantizedSerdeType::Mxfp4 => IsqType::MXFP4.pack_factor(dtype),
318 };
319 total_pack_factors += pack_factor;
320 }
321
322 total_pack_factors / total_tensors
323 };
324
325 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
326 &config,
327 dtype,
328 weight_pack_factor,
329 None,
330 )?;
331 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
332 &config,
333 dtype,
334 weight_pack_factor,
335 None,
336 )?;
337 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
338 (
339 layer_sizes_in_bytes,
340 non_mapped_size_in_bytes,
341 layer_sizes_sum + non_mapped_size_in_bytes,
342 )
343 } else if let Some(isq) = in_situ_quant {
344 let weight_pack_factor = isq.pack_factor(dtype);
345 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
346 &config,
347 dtype,
348 weight_pack_factor,
349 None,
350 )?;
351 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
352 &config,
353 dtype,
354 weight_pack_factor,
355 None,
356 )?;
357 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
358 (
359 layer_sizes_in_bytes,
360 non_mapped_size_in_bytes,
361 layer_sizes_sum + non_mapped_size_in_bytes,
362 )
363 } else {
364 let weight_pack_factor =
366 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
367 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
368 &config,
369 dtype,
370 weight_pack_factor,
371 None,
372 )?;
373 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
374 &config,
375 dtype,
376 weight_pack_factor,
377 None,
378 )?;
379 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
380 (
381 layer_sizes_in_bytes,
382 non_mapped_size_in_bytes,
383 layer_sizes_sum + non_mapped_size_in_bytes,
384 )
385 };
386
387 let new = auto_device_map::get_device_layers(
388 &*self.inner,
389 &config,
390 self.inner.num_layers(&config)?,
391 layer_sizes_in_bytes,
392 non_mapped_size_in_bytes,
393 total_model_size_in_bytes,
394 &available_devices,
395 dtype,
396 ¶ms,
397 paged_attn_config.as_ref(),
398 )?;
399 mapper = DeviceMapSetting::Map(new);
400 }
401
402 let pipeline_mapper = mapper.into_mapper(
403 self.inner.num_layers(&config)?,
404 &device,
405 self.config.topology.as_ref(),
406 &available_devices,
407 )?;
408 let mapper = mapper.into_mapper(
409 self.inner.num_layers(&config)?,
410 &device,
411 self.config.topology.as_ref(),
412 &available_devices,
413 )?;
414 let mut layer_devices = Vec::new();
415 for layer in 0..self.inner.num_layers(&config)? {
416 let device = mapper.device_for(layer, false).cloned();
417 layer_devices.push(device);
418 }
419 let dtype = mapper.get_min_dtype(dtype)?;
420
421 trace!("Model config: {:?}", self.inner.get_config_repr(&config)?);
422 if crate::using_flash_attn() {
423 once_log_info("FlashAttention is enabled.");
424 }
425
426 let topology_overrides = self
427 .config
428 .topology
429 .as_ref()
430 .map(|topology| {
431 topology
432 .pattern_overrides()
433 .into_iter()
434 .map(|(regex, layer)| ImmediateIsqOverride {
435 predicate: regex,
436 ty: layer.isq,
437 device: layer.device.clone(),
438 })
439 .collect::<Vec<_>>()
440 })
441 .unwrap_or_default();
442 let has_override_isq = topology_overrides
443 .iter()
444 .any(|override_entry| override_entry.ty.is_some());
445 let topology_requires_post_quant = self
446 .config
447 .topology
448 .as_ref()
449 .is_some_and(|topology| topology.requires_post_quantization());
450
451 let allow_immediate_cli = in_situ_quant.is_some();
452
453 let mut immediate_ty = None;
454 let mut immediate_predicates = Vec::new();
455 if allow_immediate_cli {
456 immediate_ty = in_situ_quant;
457 immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
458 info!("Applying ISQ to {in_situ_quant:?}");
459 if immediate_predicates.is_empty() {
460 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
461 }
462 }
463
464 let use_immediate = allow_immediate_cli || has_override_isq;
465 if use_immediate {
466 let (pool, num_threads) = hanzo_quant::create_isq_thread_pool(immediate_ty);
467 info!("Applying immediate ISQ in parallel on {num_threads} threads.");
468 hanzo_quant::set_immediate_isq_with_pool(
469 immediate_ty,
470 immediate_predicates.clone(),
471 topology_overrides.clone(),
472 pool,
473 );
474 }
475
476 let mut loading_isq = if use_immediate {
478 false
479 } else {
480 in_situ_quant.is_some()
481 };
482 loading_isq |= topology_requires_post_quant;
483 loading_isq |= self.config.from_uqff.is_some();
484
485 let load_device = if !loading_isq {
491 loading_isq = false;
492 if use_immediate && !crate::utils::normal::is_integrated_gpu(&device) {
493 Device::Cpu
494 } else {
495 device.clone()
496 }
497 } else {
498 Device::Cpu
499 };
500
501 let attention_mechanism = if paged_attn_config.is_some() {
502 AttentionImplementation::PagedAttention
503 } else {
504 AttentionImplementation::Eager
505 };
506
507 let multi_progress = Arc::new(new_multi_progress());
508
509 let modules_config: Vec<_> = paths
510 .get_modules()
511 .context("Embedding models require the `modules.json` file.")?
512 .to_vec();
513 assert!(matches!(
514 modules_config.first(),
515 Some(EmbeddingModulePaths::Transformer { .. })
516 ));
517
518 let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
519 for module in &modules_config {
520 match module {
521 EmbeddingModulePaths::Transformer { .. } => (),
522 EmbeddingModulePaths::Pooling { config, .. } => {
523 let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
524 modules.push(Box::new(layer));
525 }
526 EmbeddingModulePaths::Dense { config, model, .. } => {
527 let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
528 let safetensors = unsafe { MmapedSafetensors::new(model)? };
529 let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
530 let bias = if config.bias {
531 Some(safetensors.load("linear.bias", &device, Some(dtype))?)
532 } else {
533 None
534 };
535 let (out_f, in_f) = weight.dims2()?;
536 assert_eq!((out_f, in_f), (config.out_features, config.in_features));
537 if !matches!(config.activation_function, DenseActivation::Identity) {
538 anyhow::bail!("Expected Identity activation function.");
539 }
540
541 modules.push(Box::new(Linear::new(weight, bias)));
542 }
543 EmbeddingModulePaths::Normalize { .. } => {
544 modules.push(Box::new(Normalize));
545 }
546 }
547 }
548 let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);
549
550 info!(
551 "{}",
552 WeightLoadingMode::from(WeightLoadingState {
553 from_uqff: self.config.from_uqff.is_some(),
554 loading_isq,
555 immediate_isq: use_immediate,
556 write_uqff: self.config.write_uqff.is_some(),
557 })
558 .message(self.load_context.weight_target())
559 );
560
561 let mut model = if use_nccl || use_ring() {
562 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
563 dtype,
564 &device,
565 &available_devices,
566 silent,
567 &config,
568 loading_isq,
569 self.config.from_uqff.is_some(),
570 IsqOrganization::Default,
571 &*self.inner,
572 paths.as_ref(),
573 )?;
574
575 match self.kind {
577 ModelKind::Normal => embedding_normal_model_loader_sharded!(
578 sharded_vb,
579 config,
580 self.inner,
581 mapper,
582 loading_isq,
583 device.clone(),
584 attention_mechanism,
585 multi_progress.clone(),
586 ),
587 _ => unreachable!(),
588 }
589 } else {
590 match self.kind {
591 ModelKind::Normal => embedding_normal_model_loader!(
592 paths,
593 Some(dtype),
594 &load_device,
595 layer_devices.clone(),
596 config,
597 self.inner,
598 silent,
599 mapper,
600 loading_isq,
601 self.config.from_uqff.is_some(),
602 device.clone(),
603 attention_mechanism,
604 multi_progress,
605 ),
606 _ => unreachable!(),
607 }
608 };
609
610 let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
611
612 let should_serialize = self.config.write_uqff.is_some();
613 let should_quantize_pass = loading_isq;
614
615 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
616 if should_quantize_pass {
617 debug!("Applying ISQ to all ranks.");
618 } else {
619 debug!("Serializing existing ISQ tensors without additional quantization.");
620 }
621 model.quantize(
622 in_situ_quant,
623 device.clone(),
624 self.config.topology.as_ref(),
625 silent,
626 None,
627 IsqOrganization::Default,
628 should_quantize_pass,
629 self.config.write_uqff.as_ref(),
630 UqffFullSer {
631 tokenizer: &tokenizer,
632 template_filename: paths.get_template_filename(),
633 generation_config: paths.get_gen_conf_filename(),
634 config: config.clone(),
635 processor_filename: paths.get_processor_config(),
636 preprocessor_filename: paths.get_preprocessor_config(),
637 modules: Some(&modules_ser),
638 module_paths: Some(&modules_config),
639 },
640 Arc::new(new_multi_progress()),
641 )?;
642 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
643 model.load_from_artifacts(
644 device.clone(),
645 self.config.topology.as_ref(),
646 silent,
647 from_uqff,
648 )?;
649 }
650
651 let has_causal_attention = self.inner.has_causal_attention(&config)?;
652 let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
653 Ok(Arc::new(Mutex::new(EmbeddingPipeline {
654 model,
655 tokenizer: tokenizer.into(),
656 model_id: self.model_id.clone(),
657 metadata: Arc::new(GeneralMetadata {
658 max_seq_len,
659 llg_factory: None,
660 is_xlora: false,
661 no_prefix_cache: false,
662 num_hidden_layers: 1, eos_tok: vec![],
664 kind: ModelKind::Normal,
665 no_kv_cache: true, activation_dtype: dtype,
667 sliding_window: None,
668 cache_config: None,
669 cache_engine: None,
670 model_metadata: None,
671 modalities: Modalities {
672 input: vec![SupportedModality::Text],
673 output: vec![SupportedModality::Embedding],
674 },
675 }),
676 topology: self.config.topology.clone(),
677 silent,
678 config,
679 modules_ser,
680 modules_manifest: modules_config,
681 mapper: pipeline_mapper,
682 modules,
683 processor: Arc::new(EmbeddingProcessor {
684 has_causal_attention,
685 }),
686 })))
687 }
688
689 fn get_id(&self) -> String {
690 self.model_id.to_string()
691 }
692
693 fn get_kind(&self) -> ModelKind {
694 self.kind.clone()
695 }
696}
697
698impl PreProcessingMixin for EmbeddingPipeline {
699 fn get_processor(&self) -> Arc<dyn Processor> {
700 self.processor.clone()
701 }
702 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
703 None
704 }
705 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
706 None
707 }
708}
709
710impl IsqPipelineMixin for EmbeddingPipeline {
711 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
712 let device = self.device().clone();
713 self.model
714 .quantize(
715 Some(dtype),
716 device,
717 self.topology.as_ref(),
718 self.silent,
719 None,
720 IsqOrganization::Default,
721 true,
722 None,
723 UqffFullSer {
724 tokenizer: &self.tokenizer,
725 template_filename: &None,
726 generation_config: None,
727 config: self.config.clone(),
728 processor_filename: &None,
729 preprocessor_filename: &None,
730 modules: Some(&self.modules_ser),
731 module_paths: Some(&self.modules_manifest),
732 },
733 Arc::new(new_multi_progress()),
734 )
735 .map_err(anyhow::Error::msg)
736 }
737}
738
739impl CacheManagerMixin for EmbeddingPipeline {
740 fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
741 fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
742 fn set_none_cache(
743 &self,
744 _seqs: &mut [&mut Sequence],
745 _reset_non_granular: bool,
746 _modify_draft_cache: bool,
747 _load_preallocated_cache: bool,
748 ) {
749 }
750 fn cache(&self) -> &EitherCache {
751 unreachable!()
752 }
753}
754
755impl MetadataMixin for EmbeddingPipeline {
756 fn device(&self) -> Device {
757 self.model.device().clone()
758 }
759 fn get_metadata(&self) -> Arc<GeneralMetadata> {
760 self.metadata.clone()
761 }
762 fn name(&self) -> String {
763 self.model_id.clone()
764 }
765 fn reset_non_granular_state(&self) {}
766 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
767 Some(self.tokenizer.clone())
768 }
769 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
770 Some(&*self.mapper)
771 }
772}
773
774#[async_trait::async_trait]
775impl Pipeline for EmbeddingPipeline {
776 fn forward_inputs(
777 &mut self,
778 inputs: Box<dyn Any>,
779 _return_raw_logits: bool,
780 ) -> hanzo_ml::Result<ForwardInputsResult> {
781 let ModelInputs {
782 input_ids,
783 flash_meta,
784 } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
785
786 let mut xs = self.model.forward(&input_ids, &flash_meta)?;
787 for module in &self.modules {
788 xs = module.forward(&xs)?;
789 }
790
791 Ok(ForwardInputsResult::Embeddings { embeddings: xs })
792 }
793 async fn sample_causal_gen(
794 &self,
795 seqs: &mut [&mut Sequence],
796 logits: Vec<Tensor>,
797 prefix_cacher: &mut PrefixCacheManagerV2,
798 disable_eos_stop: bool,
799 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
800 ) -> Result<(), hanzo_ml::Error> {
801 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
802 }
803 fn category(&self) -> ModelCategory {
804 ModelCategory::Embedding
805 }
806}
807
808impl AnyMoePipelineMixin for EmbeddingPipeline {}