Skip to main content

hanzo_engine/pipeline/loaders/
embedding_loaders.rs

1use std::{
2    fmt::{self, Debug, Display},
3    path::PathBuf,
4    str::FromStr,
5    sync::Arc,
6};
7
8use crate::{
9    attention::ATTENTION_CHUNK_SIZE,
10    embedding_models::{
11        embedding_gemma::{EmbeddingGemma, EmbeddingGemmaConfig},
12        qwen3_embedding::{Config as Qwen3EmbeddingConfig, Model as Qwen3EmbeddingModel},
13    },
14    matformer::MatformerSliceConfig,
15    pipeline::{loaders::auto_device_map::NonMappedSubModel, NormalLoadingMetadata},
16};
17
18use crate::{
19    amoe::AnyMoeBaseModelMixin,
20    device_map::DeviceMapper,
21    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
22    pipeline::{isq::IsqModelLoader, text_models_inputs_processor::FlashParams, IsqModel},
23    utils::varbuilder_utils::DeviceForLoadTensor,
24};
25use anyhow::Result;
26use hanzo_ml::{DType, Device, Tensor};
27use hanzo_quant::log::once_log_debug;
28
29use hanzo_quant::ShardedVarBuilder;
30#[cfg(feature = "pyo3_macros")]
31use pyo3::pyclass;
32
33use regex::Regex;
34use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
35
36use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
37
38pub trait EmbeddingModel: IsqModel + AnyMoeBaseModelMixin {
39    #[allow(clippy::too_many_arguments)]
40    fn forward(&self, input_ids: &Tensor, flash_params: &FlashParams) -> hanzo_ml::Result<Tensor>;
41    fn device(&self) -> &Device;
42}
43
44pub trait EmbeddingModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
45    fn load(
46        &self,
47        config: &str,
48        vb: ShardedVarBuilder,
49        normal_loading_metadata: NormalLoadingMetadata,
50        attention_mechanism: AttentionImplementation,
51    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>>;
52    fn is_gptx(&self, config: &str) -> Result<bool>;
53    fn has_causal_attention(&self, config: &str) -> Result<bool>;
54    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
55    fn get_device_for_tensor(
56        &self,
57        config: &str,
58        _mapper: &dyn DeviceMapper,
59        loading_isq: bool,
60    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
61        if loading_isq {
62            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
63        } else {
64            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
65            let num_layers = self.model_config(config)?.num_layers();
66            let closure = move |name: String| {
67                if let Some(captures) = re.captures(&name) {
68                    captures
69                        .get(1)
70                        .and_then(|m| m.as_str().parse::<usize>().ok())
71                        .map(|l| l.min(num_layers))
72                        .map(DeviceForLoadTensor::Idx)
73                        .unwrap_or(DeviceForLoadTensor::Base)
74                } else {
75                    DeviceForLoadTensor::Base
76                }
77            };
78
79            Ok(Arc::new(closure))
80        }
81    }
82}
83
84#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
85#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
86/// The architecture to load the embedding model as.
87pub enum EmbeddingLoaderType {
88    #[serde(rename = "embeddinggemma")]
89    EmbeddingGemma,
90    #[serde(rename = "qwen3embedding")]
91    Qwen3Embedding,
92}
93
94// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
95impl EmbeddingLoaderType {
96    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
97        match name {
98            "Gemma3TextModel" => Ok(Self::EmbeddingGemma),
99            "Qwen3ForCausalLM" => Ok(Self::Qwen3Embedding),
100            other => anyhow::bail!(
101                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
102            ),
103        }
104    }
105}
106
107impl FromStr for EmbeddingLoaderType {
108    type Err = String;
109    fn from_str(s: &str) -> Result<Self, Self::Err> {
110        match s {
111            "embeddinggemma" => Ok(Self::EmbeddingGemma),
112            "qwen3embedding" => Ok(Self::Qwen3Embedding),
113            a => Err(format!(
114                "Unknown architecture `{a}`. Possible architectures: `embeddinggemma`, `qwen3embedding`."
115            )),
116        }
117    }
118}
119
120impl Display for EmbeddingLoaderType {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            Self::EmbeddingGemma => write!(f, "embeddinggemma"),
124            Self::Qwen3Embedding => write!(f, "qwen3embedding"),
125        }
126    }
127}
128
129#[derive(Clone, Debug, Deserialize)]
130pub enum EmbeddingModulePaths {
131    Transformer {
132        path: String,
133    },
134    Pooling {
135        path: String,
136        config: PathBuf,
137    },
138    Dense {
139        path: String,
140        config: PathBuf,
141        model: PathBuf,
142    },
143    Normalize {
144        path: String,
145    },
146}
147
148impl EmbeddingModulePaths {
149    pub fn serialize_modules(modules: &[EmbeddingModulePaths]) -> String {
150        #[derive(Serialize)]
151        struct OutputModule {
152            idx: usize,
153            name: String,
154            path: String,
155            #[serde(rename = "type")]
156            ty: String,
157        }
158
159        let mapped: Vec<OutputModule> = modules
160            .iter()
161            .enumerate()
162            .map(|(i, m)| {
163                let (path, ty) = match m {
164                    EmbeddingModulePaths::Transformer { path } => (
165                        path.clone(),
166                        "sentence_transformers.models.Transformer".to_string(),
167                    ),
168                    EmbeddingModulePaths::Pooling { path, .. } => (
169                        path.clone(),
170                        "sentence_transformers.models.Pooling".to_string(),
171                    ),
172                    EmbeddingModulePaths::Dense { path, .. } => (
173                        path.clone(),
174                        "sentence_transformers.models.Dense".to_string(),
175                    ),
176                    EmbeddingModulePaths::Normalize { path } => (
177                        path.clone(),
178                        "sentence_transformers.models.Normalize".to_string(),
179                    ),
180                };
181
182                OutputModule {
183                    idx: i,
184                    name: i.to_string(),
185                    path,
186                    ty,
187                }
188            })
189            .collect();
190
191        serde_json::to_string_pretty(&mapped).unwrap()
192    }
193}
194
195#[derive(Debug, Deserialize)]
196pub struct EmbeddingModule {
197    pub path: String,
198    #[serde(rename = "type", deserialize_with = "deserialize_module_type")]
199    pub ty: EmbeddingModuleType,
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Eq)]
203pub enum EmbeddingModuleType {
204    Transformer,
205    Pooling,
206    Dense,
207    Normalize,
208}
209
210fn deserialize_module_type<'de, D>(deserializer: D) -> Result<EmbeddingModuleType, D::Error>
211where
212    D: Deserializer<'de>,
213{
214    struct ModuleTypeVisitor;
215
216    impl<'de> Visitor<'de> for ModuleTypeVisitor {
217        type Value = EmbeddingModuleType;
218
219        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
220            f.write_str("a sentence-transformers module type string")
221        }
222
223        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
224        where
225            E: serde::de::Error,
226        {
227            // Accept fully-qualified ("sentence_transformers.models.X") or just "X".
228            let last = v.rsplit('.').next().unwrap_or(v).to_ascii_lowercase();
229            match last.as_str() {
230                "transformer" => Ok(EmbeddingModuleType::Transformer),
231                "pooling" => Ok(EmbeddingModuleType::Pooling),
232                "dense" => Ok(EmbeddingModuleType::Dense),
233                "normalize" => Ok(EmbeddingModuleType::Normalize),
234                _ => Err(E::invalid_value(
235                    serde::de::Unexpected::Str(v),
236                    &"Transformer/Pooling/Dense/Normalize",
237                )),
238            }
239        }
240    }
241
242    deserializer.deserialize_str(ModuleTypeVisitor)
243}
244
/// Expands to `$size` when `$cond` is true and to `0` otherwise.
///
/// Used to count a projection's bias parameters only when the config enables
/// biases. `$size` is evaluated only when `$cond` holds.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        match $cond {
            true => $size,
            false => 0,
        }
    };
}
254
255/// Load a model based on the Hugging Face Transformers -CausalLM model class
256pub struct AutoEmbeddingLoader;
257
258#[derive(Deserialize)]
259struct AutoEmbeddingLoaderConfig {
260    architectures: Vec<String>,
261}
262
263impl AutoEmbeddingLoader {
264    fn get_loader(config: &str) -> Result<Box<dyn EmbeddingModelLoader>> {
265        let auto_cfg: AutoEmbeddingLoaderConfig = serde_json::from_str(config)?;
266        if auto_cfg.architectures.len() != 1 {
267            anyhow::bail!("Expected to have one name for `architectures` config field.")
268        }
269
270        let name = &auto_cfg.architectures[0];
271
272        let tp = EmbeddingLoaderType::from_causal_lm_name(name)?;
273
274        once_log_debug(format!("Automatic loader type determined to be `{tp}`"));
275
276        match tp {
277            EmbeddingLoaderType::EmbeddingGemma => Ok(Box::new(EmbeddingGemmaLoader)),
278            EmbeddingLoaderType::Qwen3Embedding => Ok(Box::new(Qwen3EmbeddingLoader)),
279        }
280    }
281}
282
283impl EmbeddingModelLoader for AutoEmbeddingLoader {
284    fn load(
285        &self,
286        config: &str,
287        vb: ShardedVarBuilder,
288        normal_loading_metadata: NormalLoadingMetadata,
289        attention_mechanism: AttentionImplementation,
290    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
291        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
292    }
293    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
294        Self::get_loader(config)?.get_config_repr(config)
295    }
296    fn has_causal_attention(&self, config: &str) -> Result<bool> {
297        Self::get_loader(config)?.has_causal_attention(config)
298    }
299    fn is_gptx(&self, config: &str) -> Result<bool> {
300        Self::get_loader(config)?.is_gptx(config)
301    }
302}
303
304impl IsqModelLoader for AutoEmbeddingLoader {
305    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
306        Self::get_loader(config)?.immediate_isq_predicates(config)
307    }
308    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
309        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
310    }
311    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
312        Self::get_loader(config)?.isq_layer_regexes(config)
313    }
314    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
315        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
316    }
317}
318
319impl DeviceMappedModelLoader for AutoEmbeddingLoader {
320    fn non_mapped_size_in_bytes(
321        &self,
322        config: &str,
323        dtype: DType,
324        weight_pack_factor: usize,
325        _matformer_config: Option<&MatformerSliceConfig>,
326    ) -> Result<usize> {
327        Self::get_loader(config)?.non_mapped_size_in_bytes(
328            config,
329            dtype,
330            weight_pack_factor,
331            _matformer_config,
332        )
333    }
334    fn num_layers(&self, config: &str) -> Result<usize> {
335        Self::get_loader(config)?.num_layers(config)
336    }
337    fn layer_sizes_in_bytes(
338        &self,
339        config: &str,
340        dtype: DType,
341        weight_pack_factor: usize,
342        _matformer_config: Option<&MatformerSliceConfig>,
343    ) -> Result<Vec<usize>> {
344        Self::get_loader(config)?.layer_sizes_in_bytes(
345            config,
346            dtype,
347            weight_pack_factor,
348            _matformer_config,
349        )
350    }
351    fn mapped_max_act_size_elems(
352        &self,
353        config: &str,
354        params: &super::AutoDeviceMapParams,
355    ) -> Result<usize> {
356        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
357    }
358    fn non_mapped_max_act_size_elems(
359        &self,
360        _config: &str,
361        _params: &AutoDeviceMapParams,
362    ) -> Result<usize> {
363        Ok(0)
364    }
365    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
366        Self::get_loader(config)?.model_config(config)
367    }
368}
369
/// [`EmbeddingModelLoader`] for an EmbeddingGemma model (`Gemma3TextModel` architecture).
pub struct EmbeddingGemmaLoader;
374
375impl EmbeddingModelLoader for EmbeddingGemmaLoader {
376    fn load(
377        &self,
378        config: &str,
379        vb: ShardedVarBuilder,
380        normal_loading_metadata: NormalLoadingMetadata,
381        attention_mechanism: AttentionImplementation,
382    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
383        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
384
385        Ok(Box::new(EmbeddingGemma::new(
386            &cfg,
387            vb,
388            self.is_gptx(config)?,
389            normal_loading_metadata,
390            attention_mechanism,
391        )?))
392    }
393    fn is_gptx(&self, _: &str) -> Result<bool> {
394        Ok(true)
395    }
396    fn has_causal_attention(&self, _: &str) -> Result<bool> {
397        Ok(false)
398    }
399    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
400        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
401        Ok(Box::new(cfg))
402    }
403}
404
405impl IsqModelLoader for EmbeddingGemmaLoader {
406    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
407        Ok(vec![
408            Regex::new(r"lm_head\.(weight|bias)$")?,
409            // Attention
410            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
411            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
412            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
413            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
414            // MLP
415            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
416            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
417            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
418        ])
419    }
420    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
421        Ok(vec![
422            Regex::new(r"lm_head\.(weight|bias)$")?,
423            // Attention
424            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
425            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
426            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
427            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
428            // MLP
429            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
430            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
431            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
432        ])
433    }
434}
435
436impl DeviceMappedModelLoader for EmbeddingGemmaLoader {
437    fn mapped_max_act_size_elems(
438        &self,
439        config: &str,
440        params: &AutoDeviceMapParams,
441    ) -> Result<usize> {
442        let AutoDeviceMapParams::Text {
443            max_seq_len,
444            max_batch_size,
445        } = params
446        else {
447            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
448        };
449
450        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
451
452        Ok(
453            max_batch_size
454                * cfg.num_attention_heads
455                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
456        )
457    }
458
459    fn non_mapped_max_act_size_elems(
460        &self,
461        _config: &str,
462        _params: &AutoDeviceMapParams,
463    ) -> Result<usize> {
464        Ok(0)
465    }
466
467    fn non_mapped_size_in_bytes(
468        &self,
469        config: &str,
470        dtype: DType,
471        weight_pack_factor: usize,
472        _matformer_config: Option<&MatformerSliceConfig>,
473    ) -> Result<usize> {
474        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
475
476        let elems = {
477            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
478            let norm = cfg.hidden_size;
479            embed_tokens + norm
480        };
481        Ok(elems * dtype.size_in_bytes())
482    }
483
484    fn layer_sizes_in_bytes(
485        &self,
486        config: &str,
487        dtype: DType,
488        weight_pack_factor: usize,
489        _matformer_config: Option<&MatformerSliceConfig>,
490    ) -> Result<Vec<usize>> {
491        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
492
493        let per_layer_elems = {
494            let input_layernorm = cfg.hidden_size;
495            let post_attention_layernorm = cfg.hidden_size;
496
497            let size_in = cfg.hidden_size;
498            let size_q = cfg.head_dim * cfg.num_attention_heads;
499            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
500            let q_proj =
501                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
502            let k_proj =
503                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
504            let v_proj =
505                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
506            let o_proj =
507                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
508
509            let h_size = cfg.hidden_size;
510            let i_size = cfg.intermediate_size;
511            let gate_proj = h_size * i_size / weight_pack_factor;
512            let up_proj = h_size * i_size / weight_pack_factor;
513            let down_proj = i_size * h_size / weight_pack_factor;
514
515            input_layernorm
516                + post_attention_layernorm
517                + q_proj
518                + k_proj
519                + v_proj
520                + o_proj
521                + gate_proj
522                + up_proj
523                + down_proj
524        };
525        Ok(vec![
526            per_layer_elems * dtype.size_in_bytes();
527            cfg.num_hidden_layers
528        ])
529    }
530
531    fn num_layers(&self, config: &str) -> Result<usize> {
532        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
533
534        Ok(cfg.num_hidden_layers)
535    }
536
537    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
538        let cfg: EmbeddingGemmaConfig = serde_json::from_str(config)?;
539
540        let cfg = ModelConfigMetadata {
541            max_seq_len: cfg.max_position_embeddings,
542            num_layers: cfg.num_hidden_layers,
543            hidden_size: cfg.hidden_size,
544            num_kv_heads: cfg.num_key_value_heads,
545            num_attn_heads: cfg.num_attention_heads,
546            sliding_window: None, // None to be more forgiving, some do not
547            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
548            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
549            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
550        };
551
552        Ok(Box::new(cfg))
553    }
554
555    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
556        None // todo
557    }
558}
559
/// [`EmbeddingModelLoader`] for a Qwen3 embedding model (`Qwen3ForCausalLM` architecture).
pub struct Qwen3EmbeddingLoader;
564
565impl EmbeddingModelLoader for Qwen3EmbeddingLoader {
566    fn load(
567        &self,
568        config: &str,
569        vb: ShardedVarBuilder,
570        normal_loading_metadata: NormalLoadingMetadata,
571        attention_mechanism: AttentionImplementation,
572    ) -> Result<Box<dyn EmbeddingModel + Send + Sync>> {
573        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
574
575        Ok(Box::new(Qwen3EmbeddingModel::new(
576            &cfg,
577            vb,
578            self.is_gptx(config)?,
579            normal_loading_metadata,
580            attention_mechanism,
581        )?))
582    }
583    fn has_causal_attention(&self, _: &str) -> Result<bool> {
584        Ok(true)
585    }
586    fn is_gptx(&self, _: &str) -> Result<bool> {
587        Ok(true)
588    }
589    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
590        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
591
592        Ok(Box::new(cfg))
593    }
594}
595
596impl IsqModelLoader for Qwen3EmbeddingLoader {
597    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
598        Ok(vec![
599            Regex::new(r"lm_head\.(weight|bias)$")?,
600            // Attention
601            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
602            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
603            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
604            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
605            // MLP
606            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
607            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
608            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
609        ])
610    }
611    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
612        self.isq_layer_regexes(config)
613    }
614}
615
616impl DeviceMappedModelLoader for Qwen3EmbeddingLoader {
617    fn mapped_max_act_size_elems(
618        &self,
619        config: &str,
620        params: &AutoDeviceMapParams,
621    ) -> Result<usize> {
622        let AutoDeviceMapParams::Text {
623            max_seq_len,
624            max_batch_size,
625        } = params
626        else {
627            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
628        };
629
630        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
631
632        Ok(
633            max_batch_size
634                * cfg.num_attention_heads
635                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
636        )
637    }
638    fn non_mapped_max_act_size_elems(
639        &self,
640        _config: &str,
641        _params: &AutoDeviceMapParams,
642    ) -> Result<usize> {
643        Ok(0)
644    }
645
646    fn non_mapped_size_in_bytes(
647        &self,
648        config: &str,
649        dtype: DType,
650        weight_pack_factor: usize,
651        _matformer_config: Option<&MatformerSliceConfig>,
652    ) -> Result<usize> {
653        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
654        let elems = {
655            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
656            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
657            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
658                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
659            } else {
660                0
661            };
662            let norm = cfg.hidden_size;
663            embed_tokens + lm_head + norm
664        };
665        Ok(elems * dtype.size_in_bytes())
666    }
667
668    fn layer_sizes_in_bytes(
669        &self,
670        config: &str,
671        dtype: DType,
672        weight_pack_factor: usize,
673        _matformer_config: Option<&MatformerSliceConfig>,
674    ) -> Result<Vec<usize>> {
675        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
676        let per_layer_elems = {
677            let input_layernorm = cfg.hidden_size;
678            let post_attention_layernorm = cfg.hidden_size;
679
680            let size_in = cfg.hidden_size;
681            let size_q = cfg.head_dim() * cfg.num_attention_heads;
682            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
683            let q_proj = size_in * size_q / weight_pack_factor + size_q;
684            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
685            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
686            let o_proj = size_q * size_in / weight_pack_factor;
687
688            let h_size = cfg.hidden_size;
689            let i_size = cfg.intermediate_size;
690            let gate_proj = h_size * i_size / weight_pack_factor;
691            let up_proj = h_size * i_size / weight_pack_factor;
692            let down_proj = i_size * h_size / weight_pack_factor;
693
694            let q_norm = cfg.head_dim();
695            let k_norm = cfg.head_dim();
696
697            input_layernorm
698                + post_attention_layernorm
699                + q_proj
700                + k_proj
701                + v_proj
702                + o_proj
703                + gate_proj
704                + up_proj
705                + down_proj
706                + q_norm
707                + k_norm
708        };
709        Ok(vec![
710            per_layer_elems * dtype.size_in_bytes();
711            cfg.num_hidden_layers
712        ])
713    }
714
715    fn num_layers(&self, config: &str) -> Result<usize> {
716        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
717        Ok(cfg.num_hidden_layers)
718    }
719
720    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
721        let cfg: Qwen3EmbeddingConfig = serde_json::from_str(config)?;
722
723        let cfg = ModelConfigMetadata {
724            max_seq_len: cfg.max_position_embeddings,
725            num_layers: cfg.num_hidden_layers,
726            hidden_size: cfg.hidden_size,
727            num_kv_heads: cfg.num_key_value_heads,
728            num_attn_heads: cfg.num_attention_heads,
729            sliding_window: cfg.sliding_window,
730            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
731            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
732            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
733        };
734
735        Ok(Box::new(cfg))
736    }
737}