// mistralrs-core 0.8.1
//
// Fast, flexible LLM inference.
// Documentation
use candle_core::{Result, Tensor};
use candle_nn::Module;
use mistralrs_quant::ShardedVarBuilder;

use crate::{
    layers::{AvgPool2d, GemmaRmsNorm},
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::Gemma3Config;

/// Projects Gemma 3 vision-tower outputs into the text model's embedding space.
///
/// The forward pass average-pools the patch grid, applies an RMS norm, and multiplies by a
/// learned projection matrix.
pub struct Gemma3MultiModalProjector {
    /// Projection matrix, loaded with shape `(vision hidden size, text hidden size)`.
    mm_input_projection_weight: Tensor,
    /// RMS norm over the vision hidden size, applied to the pooled features before projection.
    mm_soft_emb_norm: GemmaRmsNorm,
    /// Number of patches along one side of the image (`image_size / patch_size`).
    patches_per_image: usize,
    /// Pooling layer built from `patches_per_image / isqrt(mm_tokens_per_image)`.
    avg_pool: AvgPool2d,
}

impl Gemma3MultiModalProjector {
    pub fn new(cfg: &Gemma3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let Gemma3Config::WithVision {
            text_config,
            vision_config,
            image_token_index: _,
            mm_tokens_per_image,
        } = cfg
        else {
            unreachable!()
        };

        let mm_input_projection_weight = vb.get(
            (vision_config.hidden_size, text_config.hidden_size),
            "mm_input_projection_weight",
        )?;
        let mm_soft_emb_norm = GemmaRmsNorm::new(
            vision_config.hidden_size,
            vision_config.layer_norm_eps,
            vb.pp("mm_soft_emb_norm"),
        )?;

        let patches_per_image = vision_config.image_size / vision_config.patch_size;
        let tokens_per_side = mm_tokens_per_image.isqrt();
        let kernel_size = patches_per_image / tokens_per_side;
        let avg_pool = AvgPool2d::new(kernel_size, kernel_size);

        Ok(Self {
            mm_input_projection_weight,
            mm_soft_emb_norm,
            patches_per_image,
            avg_pool,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, _, seqlen) = xs.dims3()?;

        let mut reshaped_vision_outputs = xs.transpose(1, 2)?;
        reshaped_vision_outputs = reshaped_vision_outputs.reshape((
            bs,
            seqlen,
            self.patches_per_image,
            self.patches_per_image,
        ))?;
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()?;

        let mut pooled_vision_outputs = self.avg_pool.forward(&reshaped_vision_outputs)?;
        pooled_vision_outputs = pooled_vision_outputs.flatten_from(2)?;
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)?;

        let normed_vision_outputs = self.mm_soft_emb_norm.forward(&pooled_vision_outputs)?;

        normed_vision_outputs.broadcast_matmul(&self.mm_input_projection_weight)
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("mm_soft_emb_norm").add(&self.mm_soft_emb_norm);

        let mut tensors = uvb.to_safetensors();
        tensors.push((
            "mm_input_projection_weight.weight".to_string(),
            self.mm_input_projection_weight.clone(),
        ));
        tensors
    }
}