use std::sync::Arc;

use candle_core::{DType, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};

use crate::layers::{embedding, RmsNorm, ScaledEmbedding};

use super::config::Gemma3nTextConfig;

/// Per-modality embedder for Gemma 3n: a scaled embedding table for the extra
/// modality's token ids, plus the norms and projection that map hard token
/// embeddings or soft features into the text model's hidden size.
pub struct Gemma3nMultimodalEmbedder {
    pub(crate) embedding: ScaledEmbedding,
    pub(crate) hard_embedding_norm: RmsNorm,
    pub(crate) soft_embedding_norm: RmsNorm,
    pub(crate) embedding_projection: Arc<dyn QuantMethod>,
    pub(crate) embedding_post_projection_norm: RmsNorm,
    /// Id of the first multimodal token in the combined vocabulary.
    vocab_offset: i64,
}

impl Gemma3nMultimodalEmbedder {
    /// Builds the embedder for one modality. `multimodal_vocab_size` and
    /// `multimodal_hidden_size` describe that modality's embedding table, and
    /// `vocab_offset` is the id of its first token in the combined vocabulary.
    pub fn new(
        cfg: &Gemma3nTextConfig,
        multimodal_vocab_size: usize,
        multimodal_hidden_size: usize,
        vocab_offset: i64,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let embed_tokens = embedding(
            multimodal_vocab_size,
            multimodal_hidden_size,
            vb.pp("embedding"),
            &cfg.quantization_config,
        )?;
        // Token embeddings are scaled by sqrt(hidden_size), as in Gemma.
        let embedding = ScaledEmbedding::new((multimodal_hidden_size as f64).sqrt(), embed_tokens);

        // Norm for embeddings looked up from hard (discrete) token ids.
        let hard_embedding_norm = RmsNorm::new_gemma_3n(
            multimodal_hidden_size,
            cfg.rms_norm_eps,
            true,
            vb.pp("hard_embedding_norm"),
        )?;

        // Norm for soft (continuous) features produced by the modality encoder
        // (see `forward_vision`).
        let soft_embedding_norm = RmsNorm::new_gemma_3n(
            multimodal_hidden_size,
            cfg.rms_norm_eps,
            true,
            vb.pp("soft_embedding_norm"),
        )?;

        // Projection from the multimodal hidden size into the text model's hidden size.
        let embedding_projection = mistralrs_quant::linear_no_bias(
            multimodal_hidden_size,
            cfg.hidden_size,
            &None,
            vb.pp("embedding_projection"),
        )?;

        let embedding_post_projection_norm = RmsNorm::new_gemma_3n(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            false,
            vb.pp("embedding_post_projection_norm"),
        )?;

        Ok(Self {
            embedding,
            hard_embedding_norm,
            soft_embedding_norm,
            embedding_projection,
            embedding_post_projection_norm,
            vocab_offset,
        })
    }

    /// Embeds hard (discrete) multimodal token ids and projects them into the
    /// text model's hidden space.
    pub fn forward_text(&self, input_ids: &Tensor) -> Result<Tensor> {
        // Shift ids so this modality's tokens index from 0 in its embedding
        // table; the arithmetic runs in f32 and is cast back to the original
        // index dtype.
        let adjusted_ids = if self.vocab_offset != 0 {
            let adjusted = (input_ids.to_dtype(DType::F32)? - self.vocab_offset as f64)?;
            adjusted.to_dtype(input_ids.dtype())?
        } else {
            input_ids.clone()
        };

        let embeddings = self.embedding.forward(&adjusted_ids)?;
        let normalized = self.hard_embedding_norm.forward(&embeddings)?;
        // Add a leading dimension for the quantized projection and drop it again.
        let projected = self
            .embedding_projection
            .forward_autocast(&normalized.unsqueeze(0)?)?
            .squeeze(0)?;
        self.embedding_post_projection_norm.forward(&projected)
    }

    /// Normalizes soft (continuous) multimodal features, e.g. vision-tower
    /// outputs, and projects them into the text model's hidden space.
    pub fn forward_vision(&self, soft_features: &Tensor) -> Result<Tensor> {
        let normalized = self.soft_embedding_norm.forward(soft_features)?;
        let projected = self.embedding_projection.forward_autocast(&normalized)?;
        self.embedding_post_projection_norm.forward(&projected)
    }
}
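
// A minimal, illustrative sketch (not part of the model) of the vocab-offset
// adjustment that `forward_text` applies before the embedding lookup. It uses
// only `candle_core` ops that already appear above; the offset value 262_144 is
// an assumed example, not a value read from `Gemma3nTextConfig`.
#[cfg(test)]
mod vocab_offset_sketch {
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn shifts_multimodal_ids_to_a_zero_based_range() -> Result<()> {
        // Hypothetical start of the multimodal id range in the combined vocabulary.
        let vocab_offset: i64 = 262_144;
        let input_ids = Tensor::new(&[262_144u32, 262_145, 262_150], &Device::Cpu)?;

        // Same dtype round-trip as `forward_text`: cast to f32, subtract the
        // offset, cast back to the original index dtype.
        let adjusted = (input_ids.to_dtype(DType::F32)? - vocab_offset as f64)?
            .to_dtype(input_ids.dtype())?;

        assert_eq!(adjusted.to_vec1::<u32>()?, vec![0, 1, 6]);
        Ok(())
    }
}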