pub mod audio;
pub mod config;
pub mod multimodal_embedding;
pub mod text;
pub mod vision;
use hanzo_ml::{DType, Result, Tensor, D};
use config::Gemma4Config;
use multimodal_embedding::MultimodalEmbedder;
use text::TextModel;
use vision::VisionTower;
pub use audio::AudioModel;
pub use config::{Gemma4AudioConfig, Gemma4TextConfig, Gemma4VisionConfig};
/// Gemma4 multimodal model: a text decoder together with a vision tower and an optional
/// audio tower, each paired with a `MultimodalEmbedder` applied to that tower's features.
pub struct Model {
    /// Text decoder. Provides the token embeddings and the final forward pass, and owns
    /// the KV cache cleared by `clear_kv_cache`.
    pub language_model: TextModel,
    /// Vision encoder run on `pixel_values` in `forward_multimodal`.
    pub vision_tower: VisionTower,
    /// Embedder applied to vision-tower features. `Model::new` builds it from the vision
    /// and text hidden sizes, presumably mapping the former onto the latter.
    pub embed_vision: MultimodalEmbedder,
    /// Audio encoder. `Model::new` sets this to `Some` only when `cfg.audio_config` is present.
    pub audio_tower: Option<AudioModel>,
    /// Embedder for audio-tower features. `Model::new` sets it together with `audio_tower`.
    pub embed_audio: Option<MultimodalEmbedder>,
    /// Copy of the configuration the model was built from. The image and audio
    /// placeholder token ids are read from here.
    pub cfg: Gemma4Config,
}
impl Model {
    /// Builds the model from `cfg`, loading all weights under the `model` prefix of `vb`.
    ///
    /// The vision tower and its embedder are always created. The audio tower and audio
    /// embedder are created only when `cfg.audio_config` is `Some`.
    ///
    /// # Errors
    /// Propagates any error raised by the sub-module constructors (e.g. missing weights).
    pub fn new(cfg: &Gemma4Config, vb: hanzo_nn::VarBuilder) -> Result<Self> {
        let vb = vb.pp("model");
        let text_hidden = cfg.text_config.hidden_size;

        let vision_tower = VisionTower::new(&cfg.vision_config, vb.pp("vision_tower"))?;
        let embed_vision = MultimodalEmbedder::new(
            cfg.vision_config.hidden_size,
            text_hidden,
            cfg.vision_config.rms_norm_eps,
            vb.pp("embed_vision"),
        )?;

        let (audio_tower, embed_audio) = match &cfg.audio_config {
            Some(audio_cfg) => {
                let tower = AudioModel::new(audio_cfg, vb.pp("audio_tower"))?;
                // The audio embedder's input width is the projected width when the
                // config specifies one, otherwise the audio hidden size.
                let audio_hidden = audio_cfg.output_proj_dims.unwrap_or(audio_cfg.hidden_size);
                let embed = MultimodalEmbedder::new(
                    audio_hidden,
                    text_hidden,
                    audio_cfg.rms_norm_eps,
                    vb.pp("embed_audio"),
                )?;
                (Some(tower), Some(embed))
            }
            None => (None, None),
        };

        let language_model = TextModel::new(&cfg.text_config, vb.pp("language_model"))?;
        Ok(Self {
            language_model,
            vision_tower,
            embed_vision,
            audio_tower,
            embed_audio,
            cfg: cfg.clone(),
        })
    }

    /// Text-only forward pass: delegates directly to `TextModel::forward`.
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        self.language_model.forward(input_ids, seqlen_offset)
    }

    /// Forward pass with optional image and audio inputs merged into the token embeddings.
    ///
    /// * `input_ids` – rank-2 `(batch, seq_len)` token ids.
    /// * `pixel_values` – images for the vision tower. Their embeddings replace the
    ///   positions whose id equals `cfg.image_token_id`.
    /// * `audio_mel`, `audio_mel_mask` – audio-tower inputs. Their embeddings replace the
    ///   positions whose id equals `cfg.audio_token_id`. Audio is skipped silently unless
    ///   both are given AND the model was built with an audio tower.
    /// * `seqlen_offset` – forwarded to `TextModel::forward_embeds`.
    ///
    /// # Errors
    /// Returns an error if `input_ids` is not rank 2, or if any sub-module or tensor op fails.
    pub fn forward_multimodal(
        &mut self,
        input_ids: &Tensor,
        pixel_values: Option<&[Tensor]>,
        audio_mel: Option<&Tensor>,
        audio_mel_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let mut input_embeds = self.language_model.embed_tokens(input_ids)?;

        if let Some(pixel_values) = pixel_values {
            let image_mask = Self::token_mask(input_ids, self.cfg.image_token_id as f64)?;
            let vision_features = self.vision_tower.forward(pixel_values)?;
            let image_embeds = self.embed_vision.forward(&vision_features)?;
            input_embeds = Self::merge_at_mask(&input_embeds, &image_mask, &image_embeds)?;
        }

        if let (Some(audio_mel), Some(audio_mel_mask), Some(audio_tower), Some(embed_audio)) = (
            audio_mel,
            audio_mel_mask,
            &self.audio_tower,
            &self.embed_audio,
        ) {
            let audio_mask = Self::token_mask(input_ids, self.cfg.audio_token_id as f64)?;
            let (audio_features, enc_mask) = audio_tower.forward(audio_mel, audio_mel_mask)?;
            // Entries of the returned mask equal to 0 are counted as valid frames. Only the
            // first `n_valid` frames of each item are kept, which assumes valid frames form a
            // prefix. NOTE(review): confirm against AudioModel's mask convention.
            let valid = enc_mask.eq(0.0)?;
            let batch = audio_features.dim(0)?;
            let mut segments = Vec::with_capacity(batch);
            for b in 0..batch {
                let n_valid = valid
                    .get(b)?
                    .to_dtype(DType::F32)?
                    .sum_all()?
                    .to_scalar::<f32>()? as usize;
                if n_valid > 0 {
                    segments.push(audio_features.get(b)?.narrow(0, 0, n_valid)?);
                }
            }
            if !segments.is_empty() {
                // Valid frames of all batch items are concatenated in batch order.
                let audio_feats = Tensor::cat(&segments, 0)?.unsqueeze(0)?;
                let audio_embeds = embed_audio.forward(&audio_feats)?;
                input_embeds = Self::merge_at_mask(&input_embeds, &audio_mask, &audio_embeds)?;
            }
        }

        self.language_model
            .forward_embeds(&input_embeds, seqlen_offset, b_size, seq_len)
    }

    /// Clears the language model's KV cache.
    pub fn clear_kv_cache(&mut self) {
        self.language_model.clear_kv_cache()
    }

    /// Returns a mask with the same shape as `input_ids` marking entries equal to `token_id`.
    ///
    /// Ids are compared as `f32`, which is exact for ids below 2^24.
    fn token_mask(input_ids: &Tensor, token_id: f64) -> Result<Tensor> {
        input_ids.to_dtype(DType::F32)?.eq(token_id)
    }

    /// Replaces the rows of `input_embeds` selected by `mask` with `modality_embeds`.
    ///
    /// `modality_embeds` is cast to `input_embeds`' dtype, a size-1 leading dimension is
    /// squeezed, and the rows are laid onto the selected positions by
    /// `broadcast_embed_to_mask`. Unselected rows of `input_embeds` are kept unchanged.
    fn merge_at_mask(
        input_embeds: &Tensor,
        mask: &Tensor,
        modality_embeds: &Tensor,
    ) -> Result<Tensor> {
        let dtype = input_embeds.dtype();
        let flat = modality_embeds.to_dtype(dtype)?.squeeze(0)?;
        let scattered = broadcast_embed_to_mask(&flat, mask)?;
        // Arithmetic blend: a 1 in the mask picks the modality row, a 0 keeps the text row.
        let select = mask
            .unsqueeze(D::Minus1)?
            .broadcast_as(input_embeds.shape())?
            .to_dtype(dtype)?;
        let keep = (1.0 - select.clone())?;
        (select * scattered)? + (keep * input_embeds)?
    }
}
/// Scatters `embeds` into the positions selected by `mask`, preserving order.
///
/// The k-th selected position of `mask` receives the k-th row of `embeds`. Positions are
/// counted row-major: all of batch 0 first, then batch 1, and so on. Every other position
/// is zero.
///
/// * `embeds` – modality embeddings whose last dimension is the hidden size. Any leading
///   dimensions are flattened into rows, so both `(n, hidden)` and `(1, n, hidden)` work.
/// * `mask` – rank-2 `(batch, seq_len)` placeholder mask. Any non-zero entry is selected.
///
/// Returns a `(batch, seq_len, hidden)` tensor with the dtype and device of `embeds`.
/// If `mask` selects more positions than `embeds` has rows, the surplus positions stay
/// zero. Surplus rows of `embeds` are ignored.
///
/// # Errors
/// Returns an error if `mask` is not rank 2 or if an underlying tensor op fails.
fn broadcast_embed_to_mask(embeds: &Tensor, mask: &Tensor) -> Result<Tensor> {
    let (b_sz, seq_len) = mask.dims2()?;
    let hidden = embeds.dim(D::Minus1)?;
    let device = embeds.device();

    // Flatten leading dims so rows appear in the order they should fill placeholders.
    let n_rows = if hidden == 0 {
        0
    } else {
        embeds.elem_count() / hidden
    };
    let rows = embeds.reshape((n_rows, hidden))?;

    // Append one all-zero row at index `n_rows`. Unselected and surplus positions gather it.
    let zero_row = Tensor::zeros((1, hidden), embeds.dtype(), device)?;
    let table = Tensor::cat(&[&rows, &zero_row], 0)?;
    // Embedding row counts are far below u32::MAX, so the u32 casts below cannot truncate.
    let zero_idx = n_rows as u32;

    // Read the mask to the host once, then give each selected position the next unused row.
    let selected = mask.to_dtype(DType::U8)?.flatten_all()?.to_vec1::<u8>()?;
    let mut next_row = 0usize;
    let indices: Vec<u32> = selected
        .iter()
        .map(|&flag| {
            if flag != 0 && next_row < n_rows {
                next_row += 1;
                (next_row - 1) as u32
            } else {
                zero_idx
            }
        })
        .collect();

    let indices = Tensor::from_vec(indices, b_sz * seq_len, device)?;
    table
        .index_select(&indices, 0)?
        .reshape((b_sz, seq_len, hidden))
}