1use std::{
2 fmt::{self, Debug, Display},
3 path::PathBuf,
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{
9 attention::ATTENTION_CHUNK_SIZE,
10 embedding_models::{
11 embedding_gemma::{EmbeddingGemma, EmbeddingGemmaConfig},
12 qwen3_embedding::{Config as Qwen3EmbeddingConfig, Model as Qwen3EmbeddingModel},
13 },
14 matformer::MatformerSliceConfig,
15 pipeline::{loaders::auto_device_map::NonMappedSubModel, NormalLoadingMetadata},
16};
17
18use crate::{
19 amoe::AnyMoeBaseModelMixin,
20 device_map::DeviceMapper,
21 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
22 pipeline::{isq::IsqModelLoader, text_models_inputs_processor::FlashParams, IsqModel},
23 utils::varbuilder_utils::DeviceForLoadTensor,
24};
25use anyhow::Result;
26use hanzo_ml::{DType, Device, Tensor};
27use hanzo_quant::log::once_log_debug;
28
29use hanzo_quant::ShardedVarBuilder;
30#[cfg(feature = "pyo3_macros")]
31use pyo3::pyclass;
32
33use regex::Regex;
34use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
35
36use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
37
38pub trait EmbeddingModel: IsqModel + AnyMoeBaseModelMixin {
39 #[allow(clippy::too_many_arguments)]
40 fn forward(&self, input_ids: &Tensor, flash_params: &FlashParams) -> hanzo_ml::Result<Tensor>;
41 fn device(&self) -> &Device;
42}
43
44pub trait EmbeddingModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
45 fn load(
46 &self,
47 config: &str,
48 vb: ShardedVarBuilder,
49 normal_loading_metadata: NormalLoadingMetadata,
50 attention_mechanism: AttentionImplementation,
51 ) -> Result<Box<dyn EmbeddingModel + Send + Sync>>;
52 fn is_gptx(&self, config: &str) -> Result<bool>;
53 fn has_causal_attention(&self, config: &str) -> Result<bool>;
54 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
55 fn get_device_for_tensor(
56 &self,
57 config: &str,
58 _mapper: &dyn DeviceMapper,
59 loading_isq: bool,
60 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
61 if loading_isq {
62 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
63 } else {
64 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
65 let num_layers = self.model_config(config)?.num_layers();
66 let closure = move |name: String| {
67 if let Some(captures) = re.captures(&name) {
68 captures
69 .get(1)
70 .and_then(|m| m.as_str().parse::<usize>().ok())
71 .map(|l| l.min(num_layers))
72 .map(DeviceForLoadTensor::Idx)
73 .unwrap_or(DeviceForLoadTensor::Base)
74 } else {
75 DeviceForLoadTensor::Base
76 }
77 };
78
79 Ok(Arc::new(closure))
80 }
81 }
82}
83
84#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
85#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
86pub enum EmbeddingLoaderType {
88 #[serde(rename = "embeddinggemma")]
89 EmbeddingGemma,
90 #[serde(rename = "qwen3embedding")]
91 Qwen3Embedding,
92}
93
94impl EmbeddingLoaderType {
96 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
97 match name {
98 "Gemma3TextModel" => Ok(Self::EmbeddingGemma),
99 "Qwen3ForCausalLM" => Ok(Self::Qwen3Embedding),
100 other => anyhow::bail!(
101 "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
102 ),
103 }
104 }
105}
106
107impl FromStr for EmbeddingLoaderType {
108 type Err = String;
109 fn from_str(s: &str) -> Result<Self, Self::Err> {
110 match s {
111 "embeddinggemma" => Ok(Self::EmbeddingGemma),
112 "qwen3embedding" => Ok(Self::Qwen3Embedding),
113 a => Err(format!(
114 "Unknown architecture `{a}`. Possible architectures: `embeddinggemma`, `qwen3embedding`."
115 )),
116 }
117 }
118}
119
120impl Display for EmbeddingLoaderType {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 Self::EmbeddingGemma => write!(f, "embeddinggemma"),
124 Self::Qwen3Embedding => write!(f, "qwen3embedding"),
125 }
126 }
127}
128
129#[derive(Clone, Debug, Deserialize)]
130pub enum EmbeddingModulePaths {
131 Transformer {
132 path: String,
133 },
134 Pooling {
135 path: String,
136 config: PathBuf,
137 },
138 Dense {
139 path: String,
140 config: PathBuf,
141 model: PathBuf,
142 },
143 Normalize {
144 path: String,
145 },
146}
147
148impl EmbeddingModulePaths {
149 pub fn serialize_modules(modules: &[EmbeddingModulePaths]) -> String {
150 #[derive(Serialize)]
151 struct OutputModule {
152 idx: usize,
153 name: String,
154 path: String,
155 #[serde(rename = "type")]
156 ty: String,
157 }
158
159 let mapped: Vec<OutputModule> = modules
160 .iter()
161 .enumerate()
162 .map(|(i, m)| {
163 let (path, ty) = match m {
164 EmbeddingModulePaths::Transformer { path } => (
165 path.clone(),
166 "sentence_transformers.models.Transformer".to_string(),
167 ),
168 EmbeddingModulePaths::Pooling { path, .. } => (
169 path.clone(),
170 "sentence_transformers.models.Pooling".to_string(),
171 ),
172 EmbeddingModulePaths::Dense { path, .. } => (
173 path.clone(),
174 "sentence_transformers.models.Dense".to_string(),
175 ),
176 EmbeddingModulePaths::Normalize { path } => (
177 path.clone(),
178 "sentence_transformers.models.Normalize".to_string(),
179 ),
180 };
181
182 OutputModule {
183 idx: i,
184 name: i.to_string(),
185 path,
186 ty,
187 }
188 })
189 .collect();
190
191 serde_json::to_string_pretty(&mapped).unwrap()
192 }
193}
194
195#[derive(Debug, Deserialize)]
196pub struct EmbeddingModule {
197 pub path: String,
198 #[serde(rename = "type", deserialize_with = "deserialize_module_type")]
199 pub ty: EmbeddingModuleType,
200}
201
/// Kind of a sentence-transformers module, identified by the last component of its
/// `type` string.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EmbeddingModuleType {
    /// `sentence_transformers.models.Transformer`
    Transformer,
    /// `sentence_transformers.models.Pooling`
    Pooling,
    /// `sentence_transformers.models.Dense`
    Dense,
    /// `sentence_transformers.models.Normalize`
    Normalize,
}
209
210fn deserialize_module_type<'de, D>(deserializer: D) -> Result<EmbeddingModuleType, D::Error>
211where
212 D: Deserializer<'de>,
213{
214 struct ModuleTypeVisitor;
215
216 impl<'de> Visitor<'de> for ModuleTypeVisitor {
217 type Value = EmbeddingModuleType;
218
219 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
220 f.write_str("a sentence-transformers module type string")
221 }
222
223 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
224 where
225 E: serde::de::Error,
226 {
227 let last = v.rsplit('.').next().unwrap_or(v).to_ascii_lowercase();
229 match last.as_str() {
230 "transformer" => Ok(EmbeddingModuleType::Transformer),
231 "pooling" => Ok(EmbeddingModuleType::Pooling),
232 "dense" => Ok(EmbeddingModuleType::Dense),
233 "normalize" => Ok(EmbeddingModuleType::Normalize),
234 _ => Err(E::invalid_value(
235 serde::de::Unexpected::Str(v),
236 &"Transformer/Pooling/Dense/Normalize",
237 )),
238 }
239 }
240 }
241
242 deserializer.deserialize_str(ModuleTypeVisitor)
243}
244
/// Expands to `$size` when `$cond` is true and to `0` otherwise.
///
/// Used to add an optional bias term to per-layer parameter counts.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        match $cond {
            true => $size,
            false => 0,
        }
    };
}
254
/// Loader that reads the config's `architectures` field and delegates every call to the
/// matching concrete embedding loader.
pub struct AutoEmbeddingLoader;
257
/// Minimal view of a model config: only the `architectures` list is read.
#[derive(Deserialize)]
struct AutoEmbeddingLoaderConfig {
    architectures: Vec<String>,
}
262
263impl AutoEmbeddingLoader {
264 fn get_loader(config: &str) -> Result<Box<dyn EmbeddingModelLoader>> {
265 let auto_cfg: AutoEmbeddingLoaderConfig = serde_json::from_str(config)?;
266 if auto_cfg.architectures.len() != 1 {
267 anyhow::bail!("Expected to have one name for `architectures` config field.")
268 }
269
270 let name = &auto_cfg.architectures[0];
271
272 let tp = EmbeddingLoaderType::from_causal_lm_name(name)?;
273
274 once_log_debug(format!("Automatic loader type determined to be `{tp}`"));
275
276 match tp {
277 EmbeddingLoaderType::EmbeddingGemma => Ok(Box::new(EmbeddingGemmaLoader)),
278 EmbeddingLoaderType::Qwen3Embedding => Ok(Box::new(Qwen3EmbeddingLoader)),
279 }
280 }
281}
282
283impl EmbeddingModelLoader for AutoEmbeddingLoader {
284 fn load(
285 &self,
286 config: &str,
287 vb: ShardedVarBuilder,
288 normal_loading_metadata: NormalLoadingMetadata,
289 attention_mechanism: AttentionImplementation,
290 ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
291 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
292 }
293 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
294 Self::get_loader(config)?.get_config_repr(config)
295 }
296 fn has_causal_attention(&self, config: &str) -> Result<bool> {
297 Self::get_loader(config)?.has_causal_attention(config)
298 }
299 fn is_gptx(&self, config: &str) -> Result<bool> {
300 Self::get_loader(config)?.is_gptx(config)
301 }
302}
303
304impl IsqModelLoader for AutoEmbeddingLoader {
305 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
306 Self::get_loader(config)?.immediate_isq_predicates(config)
307 }
308 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
309 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
310 }
311 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
312 Self::get_loader(config)?.isq_layer_regexes(config)
313 }
314 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
315 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
316 }
317}
318
319impl DeviceMappedModelLoader for AutoEmbeddingLoader {
320 fn non_mapped_size_in_bytes(
321 &self,
322 config: &str,
323 dtype: DType,
324 weight_pack_factor: usize,
325 _matformer_config: Option<&MatformerSliceConfig>,
326 ) -> Result<usize> {
327 Self::get_loader(config)?.non_mapped_size_in_bytes(
328 config,
329 dtype,
330 weight_pack_factor,
331 _matformer_config,
332 )
333 }
334 fn num_layers(&self, config: &str) -> Result<usize> {
335 Self::get_loader(config)?.num_layers(config)
336 }
337 fn layer_sizes_in_bytes(
338 &self,
339 config: &str,
340 dtype: DType,
341 weight_pack_factor: usize,
342 _matformer_config: Option<&MatformerSliceConfig>,
343 ) -> Result<Vec<usize>> {
344 Self::get_loader(config)?.layer_sizes_in_bytes(
345 config,
346 dtype,
347 weight_pack_factor,
348 _matformer_config,
349 )
350 }
351 fn mapped_max_act_size_elems(
352 &self,
353 config: &str,
354 params: &super::AutoDeviceMapParams,
355 ) -> Result<usize> {
356 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
357 }
358 fn non_mapped_max_act_size_elems(
359 &self,
360 _config: &str,
361 _params: &AutoDeviceMapParams,
362 ) -> Result<usize> {
363 Ok(0)
364 }
365 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
366 Self::get_loader(config)?.model_config(config)
367 }
368}
369
/// Loader for EmbeddingGemma models (Hugging Face class `Gemma3TextModel`).
pub struct EmbeddingGemmaLoader;
374
375impl EmbeddingModelLoader for EmbeddingGemmaLoader {
376 fn load(
377 &self,
378 config: &str,
379 vb: ShardedVarBuilder,
380 normal_loading_metadata: NormalLoadingMetadata,
381 attention_mechanism: AttentionImplementation,
382 ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
383 let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
384
385 Ok(Box::new(EmbeddingGemma::new(
386 &cfg,
387 vb,
388 self.is_gptx(config)?,
389 normal_loading_metadata,
390 attention_mechanism,
391 )?))
392 }
393 fn is_gptx(&self, _: &str) -> Result<bool> {
394 Ok(true)
395 }
396 fn has_causal_attention(&self, _: &str) -> Result<bool> {
397 Ok(false)
398 }
399 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
400 let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
401 Ok(Box::new(cfg))
402 }
403}
404
405impl IsqModelLoader for EmbeddingGemmaLoader {
406 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
407 Ok(vec![
408 Regex::new(r"lm_head\.(weight|bias)$")?,
409 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
411 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
412 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
413 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
414 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
416 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
417 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
418 ])
419 }
420 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
421 Ok(vec![
422 Regex::new(r"lm_head\.(weight|bias)$")?,
423 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
425 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
426 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
427 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
428 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
430 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
431 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
432 ])
433 }
434}
435
436impl DeviceMappedModelLoader for EmbeddingGemmaLoader {
437 fn mapped_max_act_size_elems(
438 &self,
439 config: &str,
440 params: &AutoDeviceMapParams,
441 ) -> Result<usize> {
442 let AutoDeviceMapParams::Text {
443 max_seq_len,
444 max_batch_size,
445 } = params
446 else {
447 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
448 };
449
450 let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
451
452 Ok(
453 max_batch_size
454 * cfg.num_attention_heads
455 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
456 )
457 }
458
459 fn non_mapped_max_act_size_elems(
460 &self,
461 _config: &str,
462 _params: &AutoDeviceMapParams,
463 ) -> Result<usize> {
464 Ok(0)
465 }
466
467 fn non_mapped_size_in_bytes(
468 &self,
469 config: &str,
470 dtype: DType,
471 weight_pack_factor: usize,
472 _matformer_config: Option<&MatformerSliceConfig>,
473 ) -> Result<usize> {
474 let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
475
476 let elems = {
477 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
478 let norm = cfg.hidden_size;
479 embed_tokens + norm
480 };
481 Ok(elems * dtype.size_in_bytes())
482 }
483
484 fn layer_sizes_in_bytes(
485 &self,
486 config: &str,
487 dtype: DType,
488 weight_pack_factor: usize,
489 _matformer_config: Option<&MatformerSliceConfig>,
490 ) -> Result<Vec<usize>> {
491 let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
492
493 let per_layer_elems = {
494 let input_layernorm = cfg.hidden_size;
495 let post_attention_layernorm = cfg.hidden_size;
496
497 let size_in = cfg.hidden_size;
498 let size_q = cfg.head_dim * cfg.num_attention_heads;
499 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
500 let q_proj =
501 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
502 let k_proj =
503 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
504 let v_proj =
505 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
506 let o_proj =
507 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
508
509 let h_size = cfg.hidden_size;
510 let i_size = cfg.intermediate_size;
511 let gate_proj = h_size * i_size / weight_pack_factor;
512 let up_proj = h_size * i_size / weight_pack_factor;
513 let down_proj = i_size * h_size / weight_pack_factor;
514
515 input_layernorm
516 + post_attention_layernorm
517 + q_proj
518 + k_proj
519 + v_proj
520 + o_proj
521 + gate_proj
522 + up_proj
523 + down_proj
524 };
525 Ok(vec![
526 per_layer_elems * dtype.size_in_bytes();
527 cfg.num_hidden_layers
528 ])
529 }
530
531 fn num_layers(&self, config: &str) -> Result<usize> {
532 let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
533
534 Ok(cfg.num_hidden_layers)
535 }
536
537 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
538 let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
539
540 let cfg = ModelConfigMetadata {
541 max_seq_len: cfg.max_position_embeddings,
542 num_layers: cfg.num_hidden_layers,
543 hidden_size: cfg.hidden_size,
544 num_kv_heads: cfg.num_key_value_heads,
545 num_attn_heads: cfg.num_attention_heads,
546 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
548 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
549 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
550 };
551
552 Ok(Box::new(cfg))
553 }
554
555 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
556 None }
558}
559
/// Loader for Qwen3 embedding models (Hugging Face class `Qwen3ForCausalLM`).
pub struct Qwen3EmbeddingLoader;
564
565impl EmbeddingModelLoader for Qwen3EmbeddingLoader {
566 fn load(
567 &self,
568 config: &str,
569 vb: ShardedVarBuilder,
570 normal_loading_metadata: NormalLoadingMetadata,
571 attention_mechanism: AttentionImplementation,
572 ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
573 let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
574
575 Ok(Box::new(Qwen3EmbeddingModel::new(
576 &cfg,
577 vb,
578 self.is_gptx(config)?,
579 normal_loading_metadata,
580 attention_mechanism,
581 )?))
582 }
583 fn has_causal_attention(&self, _: &str) -> Result<bool> {
584 Ok(true)
585 }
586 fn is_gptx(&self, _: &str) -> Result<bool> {
587 Ok(true)
588 }
589 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
590 let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
591
592 Ok(Box::new(cfg))
593 }
594}
595
596impl IsqModelLoader for Qwen3EmbeddingLoader {
597 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
598 Ok(vec![
599 Regex::new(r"lm_head\.(weight|bias)$")?,
600 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
602 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
603 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
604 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
605 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
607 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
608 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
609 ])
610 }
611 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
612 self.isq_layer_regexes(config)
613 }
614}
615
616impl DeviceMappedModelLoader for Qwen3EmbeddingLoader {
617 fn mapped_max_act_size_elems(
618 &self,
619 config: &str,
620 params: &AutoDeviceMapParams,
621 ) -> Result<usize> {
622 let AutoDeviceMapParams::Text {
623 max_seq_len,
624 max_batch_size,
625 } = params
626 else {
627 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
628 };
629
630 let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
631
632 Ok(
633 max_batch_size
634 * cfg.num_attention_heads
635 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
636 )
637 }
638 fn non_mapped_max_act_size_elems(
639 &self,
640 _config: &str,
641 _params: &AutoDeviceMapParams,
642 ) -> Result<usize> {
643 Ok(0)
644 }
645
646 fn non_mapped_size_in_bytes(
647 &self,
648 config: &str,
649 dtype: DType,
650 weight_pack_factor: usize,
651 _matformer_config: Option<&MatformerSliceConfig>,
652 ) -> Result<usize> {
653 let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
654 let elems = {
655 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
656 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
658 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
659 } else {
660 0
661 };
662 let norm = cfg.hidden_size;
663 embed_tokens + lm_head + norm
664 };
665 Ok(elems * dtype.size_in_bytes())
666 }
667
668 fn layer_sizes_in_bytes(
669 &self,
670 config: &str,
671 dtype: DType,
672 weight_pack_factor: usize,
673 _matformer_config: Option<&MatformerSliceConfig>,
674 ) -> Result<Vec<usize>> {
675 let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
676 let per_layer_elems = {
677 let input_layernorm = cfg.hidden_size;
678 let post_attention_layernorm = cfg.hidden_size;
679
680 let size_in = cfg.hidden_size;
681 let size_q = cfg.head_dim() * cfg.num_attention_heads;
682 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
683 let q_proj = size_in * size_q / weight_pack_factor + size_q;
684 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
685 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
686 let o_proj = size_q * size_in / weight_pack_factor;
687
688 let h_size = cfg.hidden_size;
689 let i_size = cfg.intermediate_size;
690 let gate_proj = h_size * i_size / weight_pack_factor;
691 let up_proj = h_size * i_size / weight_pack_factor;
692 let down_proj = i_size * h_size / weight_pack_factor;
693
694 let q_norm = cfg.head_dim();
695 let k_norm = cfg.head_dim();
696
697 input_layernorm
698 + post_attention_layernorm
699 + q_proj
700 + k_proj
701 + v_proj
702 + o_proj
703 + gate_proj
704 + up_proj
705 + down_proj
706 + q_norm
707 + k_norm
708 };
709 Ok(vec![
710 per_layer_elems * dtype.size_in_bytes();
711 cfg.num_hidden_layers
712 ])
713 }
714
715 fn num_layers(&self, config: &str) -> Result<usize> {
716 let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
717 Ok(cfg.num_hidden_layers)
718 }
719
720 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
721 let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
722
723 let cfg = ModelConfigMetadata {
724 max_seq_len: cfg.max_position_embeddings,
725 num_layers: cfg.num_hidden_layers,
726 hidden_size: cfg.hidden_size,
727 num_kv_heads: cfg.num_key_value_heads,
728 num_attn_heads: cfg.num_attention_heads,
729 sliding_window: cfg.sliding_window,
730 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
731 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
732 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
733 };
734
735 Ok(Box::new(cfg))
736 }
737}