#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::{Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::sync::Arc;
use crate::{layers::RmsNorm, utils::unvarbuilder::UnVarBuilder};
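
/// Maps multimodal soft-token features from `multimodal_hidden_size` to
/// `text_hidden_size`: an RMSNorm over the incoming features followed by a
/// bias-free linear projection into the text model's embedding space.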
pub struct Gemma4MultimodalEmbedder {
    pub(crate) embedding_projection: Arc<dyn QuantMethod>,
    pub(crate) embedding_pre_projection_norm: RmsNorm,
}

impl Gemma4MultimodalEmbedder {
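    /// Loads the pre-projection RMSNorm and the projection weights from `vb`
    /// under the `embedding_pre_projection_norm` and `embedding_projection` prefixes.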
    pub fn new(
        multimodal_hidden_size: usize,
        text_hidden_size: usize,
        eps: f64,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let embedding_projection = mistralrs_quant::linear_no_bias(
            multimodal_hidden_size,
            text_hidden_size,
            &None,
            vb.pp("embedding_projection"),
        )?;
        let embedding_pre_projection_norm = RmsNorm::new_gemma_3n(
            multimodal_hidden_size,
            eps,
            false,
            vb.pp("embedding_pre_projection_norm"),
        )?;
        Ok(Self {
            embedding_projection,
            embedding_pre_projection_norm,
        })
    }
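
    /// Normalizes `soft_features` with the pre-projection RMSNorm and projects
    /// the result to the text hidden size.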
    pub fn forward(&self, soft_features: &Tensor) -> Result<Tensor> {
        // Match the norm weight's dtype before normalizing to avoid a dtype
        // mismatch inside the RmsNorm.
        let norm_dtype = self.embedding_pre_projection_norm.weight().dtype();
        let soft_features = if soft_features.dtype() != norm_dtype {
            soft_features.to_dtype(norm_dtype)?
        } else {
            soft_features.clone()
        };
        let normed = self.embedding_pre_projection_norm.forward(&soft_features)?;
        self.embedding_projection.forward_autocast(&normed)
    }
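
    /// Collects this embedder's weights under the `embedding_projection` and
    /// `embedding_pre_projection_norm` prefixes for safetensors-style serialization.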
    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        uvb.pp("embedding_projection")
            .add(&self.embedding_projection);
        uvb.pp("embedding_pre_projection_norm")
            .add(&self.embedding_pre_projection_norm);
        uvb.to_safetensors()
    }
}