1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use hanzo_ml::{DType, Module, Result, Tensor};
use hanzo_quant::{QuantMethod, ShardedVarBuilder};
use std::sync::Arc;
use crate::layers::{embedding, RmsNorm, ScaledEmbedding};
use super::config::Gemma3nTextConfig;
/// Embedder that maps either token IDs or pre-computed soft features into the
/// text model's hidden space for Gemma3n.
///
/// Token IDs go through [`Gemma3nMultimodalEmbedder::forward_text`]
/// (lookup, "hard" norm, projection, post-projection norm); soft features go
/// through [`Gemma3nMultimodalEmbedder::forward_vision`] ("soft" norm,
/// projection, post-projection norm). Both paths share the same projection and
/// post-projection normalization.
pub struct Gemma3nMultimodalEmbedder {
    /// Token lookup table of shape `multimodal_vocab_size x multimodal_hidden_size`,
    /// wrapped so it is constructed with a scale of `sqrt(multimodal_hidden_size)`.
    pub(crate) embedding: ScaledEmbedding,
    /// RMS norm (with learned scale) applied to looked-up token embeddings.
    pub(crate) hard_embedding_norm: RmsNorm,
    /// RMS norm (with learned scale) applied to incoming soft features.
    pub(crate) soft_embedding_norm: RmsNorm,
    /// Bias-free linear map from `multimodal_hidden_size` to the text `hidden_size`.
    pub(crate) embedding_projection: Arc<dyn QuantMethod>,
    /// RMS norm built with `with_scale = false`, applied after the projection.
    pub(crate) embedding_post_projection_norm: RmsNorm,
    /// Amount subtracted from incoming token IDs before the embedding lookup.
    vocab_offset: i64,
}
impl Gemma3nMultimodalEmbedder {
    /// Builds the embedder by loading every sub-layer from `vb`.
    ///
    /// # Arguments
    /// * `cfg` - text config; supplies `hidden_size`, `rms_norm_eps` and
    ///   `quantization_config`.
    /// * `multimodal_vocab_size` - number of rows in the embedding table.
    /// * `multimodal_hidden_size` - width of the embedding table and of the
    ///   features fed to the projection.
    /// * `vocab_offset` - value subtracted from token IDs in [`Self::forward_text`].
    /// * `vb` - var builder; weights are read from the `embedding`,
    ///   `hard_embedding_norm`, `soft_embedding_norm`, `embedding_projection`
    ///   and `embedding_post_projection_norm` prefixes.
    ///
    /// # Errors
    /// Propagates any error returned while constructing a sub-layer.
    pub fn new(
        cfg: &Gemma3nTextConfig,
        multimodal_vocab_size: usize,
        multimodal_hidden_size: usize,
        vocab_offset: i64,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        // Both input-side norms share everything except their weight prefix.
        let scaled_input_norm = |prefix: &str| {
            RmsNorm::new_gemma_3n(
                multimodal_hidden_size,
                cfg.rms_norm_eps,
                true,
                vb.pp(prefix),
            )
        };

        // Lookup table over the multimodal vocabulary, scaled by sqrt(width).
        let table = embedding(
            multimodal_vocab_size,
            multimodal_hidden_size,
            vb.pp("embedding"),
            &cfg.quantization_config,
        )?;
        let scale = (multimodal_hidden_size as f64).sqrt();
        let scaled_table = ScaledEmbedding::new(scale, table);

        let hard_norm = scaled_input_norm("hard_embedding_norm")?;
        let soft_norm = scaled_input_norm("soft_embedding_norm")?;

        // Bias-free projection into the text hidden size; `&None` is passed
        // where the embedding above passes `cfg.quantization_config`.
        let projection = hanzo_quant::linear_no_bias(
            multimodal_hidden_size,
            cfg.hidden_size,
            &None,
            vb.pp("embedding_projection"),
        )?;

        // Output-side norm operates on the text hidden size and has no scale.
        let post_projection_norm = RmsNorm::new_gemma_3n(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            false,
            vb.pp("embedding_post_projection_norm"),
        )?;

        Ok(Self {
            embedding: scaled_table,
            hard_embedding_norm: hard_norm,
            soft_embedding_norm: soft_norm,
            embedding_projection: projection,
            embedding_post_projection_norm: post_projection_norm,
            vocab_offset,
        })
    }

    /// Embeds token IDs and maps them into the text hidden space.
    ///
    /// Steps: subtract `vocab_offset` from `input_ids`, look the result up in
    /// the embedding table, apply the hard-embedding norm, project, then apply
    /// the post-projection norm.
    ///
    /// # Errors
    /// Propagates any tensor or layer error from the steps above.
    pub fn forward_text(&self, input_ids: &Tensor) -> Result<Tensor> {
        let local_ids = match self.vocab_offset {
            0 => input_ids.clone(),
            offset => {
                // The shift is computed in F32 and cast back to the ids' own dtype.
                // NOTE(review): presumably done because scalar subtraction is not
                // usable directly on the id dtype — confirm.
                let ids_f32 = input_ids.to_dtype(DType::F32)?;
                let shifted = (ids_f32 - offset as f64)?;
                shifted.to_dtype(input_ids.dtype())?
            }
        };

        let embedded = self.embedding.forward(&local_ids)?;
        let hard_normed = self.hard_embedding_norm.forward(&embedded)?;

        // The projection is fed a tensor with an extra leading dimension, which
        // is dropped again from its output.
        let expanded = hard_normed.unsqueeze(0)?;
        let projected = self.embedding_projection.forward(&expanded)?;
        let projected = projected.squeeze(0)?;

        self.embedding_post_projection_norm.forward(&projected)
    }

    /// Maps pre-computed soft features into the text hidden space.
    ///
    /// Steps: soft-embedding norm, projection, post-projection norm. No
    /// embedding lookup or offset adjustment takes place on this path.
    ///
    /// # Errors
    /// Propagates any error from the norm or projection layers.
    pub fn forward_vision(&self, soft_features: &Tensor) -> Result<Tensor> {
        self.soft_embedding_norm
            .forward(soft_features)
            .and_then(|soft_normed| self.embedding_projection.forward(&soft_normed))
            .and_then(|projected| self.embedding_post_projection_norm.forward(&projected))
    }
}