#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::{any::Any, sync::Arc};
use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use crate::{
device_map::DeviceMapper,
pipeline::{
text_models_inputs_processor::{
self, get_completion_input, get_prompt_input, PagedAttentionMeta,
},
InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
},
sequence::{build_mm_features_from_ranges, find_image_placeholder_ranges, Sequence},
vision_models::gemma4::audio_processing::AudioProcessor,
vision_models::{
image_processor::{ImagePreProcessor, PreprocessedImages},
preprocessor_config::{PreProcessorConfig, ToFilter},
processor_config::ProcessorConfig,
ModelInputs,
},
};
use super::Gemma4SpecificArgs;
const IMAGE_TOKEN: &str = "<|image|>";
const BOI_TOKEN: &str = "<|image>";
const EOI_TOKEN: &str = "<image|>";
pub const IMAGE_TOKEN_ID: u32 = 258880;
const AUDIO_TOKEN: &str = "<|audio|>";
const BOA_TOKEN: &str = "<|audio>";
const EOA_TOKEN: &str = "<audio|>";
pub const AUDIO_TOKEN_ID: u32 = 258881;
const VIDEO_TOKEN: &str = "<|video|>";
pub const VIDEO_TOKEN_ID: u32 = 258884;
pub struct Gemma4Processor {
patch_size: usize,
pooling_kernel_size: usize,
default_output_length: usize,
max_patches: usize,
audio_seq_length: usize,
video_max_soft_tokens: usize,
supports_images: bool,
supports_audio: bool,
}
impl Gemma4Processor {
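    /// Build a processor from the preprocessor config plus the vision tower's
    /// patch/pooling geometry. `max_patches` is the pre-pooling patch budget:
    /// each pooled output token covers `pooling_kernel_size^2` patches, so
    /// e.g. 280 output tokens with 3x3 pooling allow 280 * 9 = 2520 patches
    /// (illustrative numbers taken from the tests at the bottom of this file).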
pub fn new(
processor_config: ProcessorConfig,
patch_size: usize,
pooling_kernel_size: usize,
default_output_length: usize,
supports_images: bool,
supports_audio: bool,
) -> Self {
let max_patches = default_output_length * pooling_kernel_size * pooling_kernel_size;
let audio_seq_length = processor_config.audio_seq_length.unwrap_or(750);
let video_max_soft_tokens = processor_config.video_max_soft_tokens.unwrap_or(70);
Self {
patch_size,
pooling_kernel_size,
default_output_length,
max_patches,
audio_seq_length,
video_max_soft_tokens,
supports_images,
supports_audio,
}
}
}
impl Processor for Gemma4Processor {
fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
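        // Same pre-pooling budget computation as `max_patches` in `new`, but
        // sized by the per-video soft-token cap.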
let video_max_patches =
self.video_max_soft_tokens * self.pooling_kernel_size * self.pooling_kernel_size;
Arc::new(Gemma4ImageProcessor {
patch_size: self.patch_size,
pooling_kernel_size: self.pooling_kernel_size,
default_output_length: self.default_output_length,
max_patches: self.max_patches,
audio_seq_length: self.audio_seq_length,
video_max_soft_tokens: self.video_max_soft_tokens,
video_max_patches,
supports_images: self.supports_images,
supports_audio: self.supports_audio,
})
}
fn get_special_tokens(&self) -> &[&'static str] {
&[
IMAGE_TOKEN,
BOI_TOKEN,
EOI_TOKEN,
AUDIO_TOKEN,
BOA_TOKEN,
EOA_TOKEN,
VIDEO_TOKEN,
]
}
fn template_action(&self) -> MessagesAction {
MessagesAction::Keep
}
}
#[allow(dead_code)]
struct Gemma4ImageProcessor {
patch_size: usize,
pooling_kernel_size: usize,
default_output_length: usize,
max_patches: usize,
audio_seq_length: usize,
video_max_soft_tokens: usize,
video_max_patches: usize,
supports_images: bool,
supports_audio: bool,
}
impl Gemma4ImageProcessor {
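    /// Number of soft tokens produced for an image resized to `new_h` x `new_w`:
    /// the patch grid divided by the pooling area. With 16-px patches and 3x3
    /// pooling (the illustrative values used in the tests below), a 768x768
    /// image yields (48 * 48) / 9 = 256 tokens.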
fn output_tokens_for_size(&self, new_h: usize, new_w: usize) -> usize {
let ph = new_h / self.patch_size;
let pw = new_w / self.patch_size;
let pool_area = self.pooling_kernel_size * self.pooling_kernel_size;
(ph * pw) / pool_area
}
    fn compute_resize_dims(&self, orig_h: usize, orig_w: usize) -> Result<(usize, usize)> {
        self.resize_dims_for_budget(orig_h, orig_w, self.max_patches, "image")
    }
    /// Shared resize logic for still images and video frames: scale the input
    /// to fill a `max_patches` patch budget, snapping both sides down to the
    /// pooling grid (`pooling_kernel_size * patch_size` pixels).
    fn resize_dims_for_budget(
        &self,
        orig_h: usize,
        orig_w: usize,
        max_patches: usize,
        kind: &str,
    ) -> Result<(usize, usize)> {
        if orig_h == 0 || orig_w == 0 {
            candle_core::bail!(
                "Gemma4 {kind} resize: input dimensions must be non-zero, got {orig_h}x{orig_w}"
            );
        }
        let target_px = max_patches * self.patch_size * self.patch_size;
        let grid_unit = self.pooling_kernel_size * self.patch_size;
        let pool_area = self.pooling_kernel_size * self.pooling_kernel_size;
        let max_side_length = (max_patches / pool_area) * grid_unit;
        let factor = (target_px as f64 / (orig_h as f64 * orig_w as f64)).sqrt();
        let ideal_h = orig_h as f64 * factor;
        let ideal_w = orig_w as f64 * factor;
        let mut new_h = (ideal_h / grid_unit as f64).floor() as usize * grid_unit;
        let mut new_w = (ideal_w / grid_unit as f64).floor() as usize * grid_unit;
        if new_h == 0 && new_w == 0 {
            candle_core::bail!(
                "Gemma4 {kind} resize: both dimensions round to 0 for input {orig_h}x{orig_w}"
            );
        }
        // For extreme aspect ratios one side rounds to zero; clamp it to a
        // single grid unit and derive the other side from the aspect ratio in
        // floating point (integer division here would truncate the ratio).
        if new_h == 0 {
            new_h = grid_unit;
            new_w = (((orig_w as f64 / orig_h as f64) * grid_unit as f64) as usize)
                .min(max_side_length);
            new_w = (new_w / grid_unit).max(1) * grid_unit;
        } else if new_w == 0 {
            new_w = grid_unit;
            new_h = (((orig_h as f64 / orig_w as f64) * grid_unit as f64) as usize)
                .min(max_side_length);
            new_h = (new_h / grid_unit).max(1) * grid_unit;
        }
        if new_h * new_w > target_px {
            candle_core::bail!(
                "Gemma4 {kind} resize: {new_h}x{new_w} = {} pixels exceeds patch budget of {target_px} \
                 for input {orig_h}x{orig_w}",
                new_h * new_w
            );
        }
        Ok((new_h, new_w))
    }
    fn build_image_sequence(&self, num_tokens: usize) -> String {
        let image_tokens = IMAGE_TOKEN.repeat(num_tokens);
        format!("{BOI_TOKEN}{image_tokens}{EOI_TOKEN}")
    }
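    /// Token count for an audio clip given its mel-frame count. Each loop
    /// iteration applies the usual conv output-length formula
    /// `(t + 2 * pad - kernel) / stride + 1` with pad = 1, kernel = 3,
    /// stride = 2 (written out as `t + 2 - 3` below), i.e. two stride-2
    /// subsampling stages; the result is capped at `audio_seq_length`. The
    /// pad/kernel/stride reading is inferred from the constants, not verified
    /// against the reference audio encoder.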
fn compute_audio_num_tokens(&self, num_mel_frames: usize) -> usize {
if num_mel_frames == 0 {
return 0;
}
let mut t = num_mel_frames;
for _ in 0..2 {
t = (t + 2 - 3) / 2 + 1;
}
t.min(self.audio_seq_length)
}
    fn build_audio_sequence(&self, num_tokens: usize) -> String {
        let audio_tokens = AUDIO_TOKEN.repeat(num_tokens);
        format!("{BOA_TOKEN}{audio_tokens}{EOA_TOKEN}")
    }
    fn compute_video_resize_dims(&self, orig_h: usize, orig_w: usize) -> Result<(usize, usize)> {
        self.resize_dims_for_budget(orig_h, orig_w, self.video_max_patches, "video")
    }
    fn video_tokens_for_size(&self, new_h: usize, new_w: usize) -> usize {
        // Video frames share the still-image patch/pooling geometry.
        self.output_tokens_for_size(new_h, new_w)
    }
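    /// Render the expanded prompt text for a video: each frame becomes
    /// `"<timestamp> {BOI_TOKEN}<video tokens>{EOI_TOKEN}"`, and frames are
    /// joined with single spaces.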
    fn build_video_sequence(&self, timestamps: &[String], tokens_per_frame: usize) -> String {
        let video_tokens = VIDEO_TOKEN.repeat(tokens_per_frame);
        timestamps
            .iter()
            .map(|ts| format!("{ts} {BOI_TOKEN}{video_tokens}{EOI_TOKEN}"))
            .collect::<Vec<_>>()
            .join(" ")
    }
}
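/// For each placeholder range `(offset, length)` in the token stream, count how
/// many of its tokens fall inside an already-cached prefix of `prefix_len`
/// tokens: 0 if the prefix ends at or before the range, `length` if it covers
/// the whole range, and a partial count in between (see the tests at the
/// bottom of this file).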
fn cached_tokens_for_ranges(prefix_len: usize, ranges: &[(usize, usize)]) -> Vec<usize> {
ranges
.iter()
.map(|&(offset, length)| prefix_len.saturating_sub(offset).min(length))
.collect()
}
impl InputsProcessor for Gemma4ImageProcessor {
fn get_type(&self) -> InputsProcessorType {
InputsProcessorType::Vision
}
fn process_inputs(
&self,
tokenizer: Option<Arc<Tokenizer>>,
input_seqs: &mut [&mut Sequence],
is_prompt: bool,
is_xlora: bool,
device: &Device,
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
return_raw_logits: bool,
sliding_window: Option<usize>,
other_config: Option<Arc<dyn Any>>,
mut paged_attn_metadata: Option<PagedAttentionMeta>,
mapper: Option<&dyn DeviceMapper>,
) -> anyhow::Result<InputProcessorOutput> {
if is_xlora {
return Err(anyhow::Error::msg(
"Cannot make inputs for X-LoRA vision model.",
));
}
if no_kv_cache {
return Err(anyhow::Error::msg("Vision model must have kv cache."));
}
let Some(tokenizer) = tokenizer else {
return Err(anyhow::Error::msg(
"Gemma4ImageProcessor requires a specified tokenizer.",
));
};
        let config = other_config.expect("Need a PreProcessorConfig.");
        let preprocessor_config: &PreProcessorConfig = config
            .downcast_ref()
            .expect("Downcast to PreProcessorConfig failed.");
let has_images = input_seqs.iter().any(|seq| seq.has_images());
let has_audios = input_seqs.iter().any(|seq| seq.has_audios());
let has_videos = input_seqs.iter().any(|seq| seq.has_videos());
let mut has_changed_prompt = false;
let mut image_hashes_accum = Vec::new();
let mut image_cached_tokens_accum = Vec::new();
let mut audio_hashes_accum = Vec::new();
let mut audio_cached_tokens_accum = Vec::new();
let mut video_pixel_values_accum = Vec::new();
let mut video_hashes_accum = Vec::new();
let mut video_cached_tokens_accum = Vec::new();
let mut video_sizes_accum = Vec::new();
if has_audios && !self.supports_audio {
return Err(anyhow::Error::msg(
"This image processor does not support audio.",
));
}
let (audio_mel, audio_mel_mask) = if has_audios {
let mut audio_mel_accum = Vec::new();
let mut audio_mask_accum = Vec::new();
let audio_processor = AudioProcessor::new(preprocessor_config);
for seq in input_seqs.iter_mut() {
if let Some(audios) = seq.take_audios() {
let (seq_audio_mel, seq_audio_mask, seq_audio_frame_counts) =
audio_processor.process_audios(&audios, device)?;
let seq_audio_num_tokens = seq_audio_frame_counts
.into_iter()
.map(|num_frames| self.compute_audio_num_tokens(num_frames))
.collect::<Vec<_>>();
if !seq.multimodal.has_changed_prompt {
let mut prompt = tokenizer
.decode(seq.get_toks(), false)
.expect("Detokenization failed!");
let positions: Vec<usize> = prompt
.match_indices(AUDIO_TOKEN)
.map(|(idx, _)| idx)
.collect();
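                    // Replace placeholders back-to-front so earlier byte
                    // offsets stay valid while the string grows.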
for (i, &pos) in positions.iter().enumerate().rev() {
let num_tokens = seq_audio_num_tokens
.get(i)
.copied()
.unwrap_or(self.audio_seq_length);
let replacement = self.build_audio_sequence(num_tokens);
prompt = format!(
"{}{}{}",
&prompt[..pos],
replacement,
&prompt[pos + AUDIO_TOKEN.len()..],
);
}
seq.set_initial_prompt(prompt.clone());
let toks = tokenizer
.encode_fast(prompt.as_str(), false)
.expect("Tokenization failed!");
let ids = toks.get_ids().to_vec();
seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
has_changed_prompt = true;
}
let n_audio = audios.len();
let audio_ranges =
find_image_placeholder_ranges(seq.get_toks(), AUDIO_TOKEN_ID);
let cached_audio_tokens =
cached_tokens_for_ranges(seq.prefix_cache_len(), &audio_ranges);
let seq_audio_hashes = seq.audio_hashes().unwrap_or(&[]);
                    for idx in 0..n_audio {
                        let total_tokens = audio_ranges
                            .get(idx)
                            .map(|(_, length)| *length)
                            .unwrap_or_else(|| {
                                seq_audio_num_tokens
                                    .get(idx)
                                    .copied()
                                    .unwrap_or(self.audio_seq_length)
                            });
                        let cached_tokens = cached_audio_tokens
                            .get(idx)
                            .copied()
                            .unwrap_or(0)
                            .min(total_tokens);
                        if cached_tokens >= total_tokens {
                            continue;
                        }
                        audio_mel_accum.push(seq_audio_mel.get(idx)?.unsqueeze(0)?);
                        audio_mask_accum.push(seq_audio_mask.get(idx)?.unsqueeze(0)?);
                        if let Some(&hash) = seq_audio_hashes.get(idx) {
                            audio_hashes_accum.push(hash);
                        }
                        audio_cached_tokens_accum.push(cached_tokens);
                    }
}
}
            if !audio_mel_accum.is_empty() {
                (
                    Some(Tensor::cat(&audio_mel_accum, 0)?),
                    Some(Tensor::cat(&audio_mask_accum, 0)?),
                )
            } else {
                (None, None)
            }
} else {
(None, None)
};
let pixel_values = if has_images {
if !self.supports_images {
return Err(anyhow::Error::msg(
"This image processor does not support images.",
));
}
let mut pixel_values_accum = Vec::new();
let mut image_sizes_accum = Vec::new();
for seq in input_seqs.iter_mut() {
let images = seq
.take_images()
.expect("Need to have images by this point.");
let per_image_dims: Vec<(usize, usize)> = images
.iter()
.map(|img| {
let (w, h) = img.dimensions();
self.compute_resize_dims(h as usize, w as usize)
})
.collect::<Result<Vec<_>>>()?;
let PreprocessedImages {
pixel_values,
pixel_attention_mask: _,
image_sizes: _,
num_img_tokens: _,
aspect_ratio_ids: _,
aspect_ratio_mask: _,
num_tiles: _,
image_grid_thw: _,
video_grid_thw: _,
rows: _,
cols: _,
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all,
num_crops: _,
} = self
.preprocess(
images,
vec![],
preprocessor_config,
device,
(usize::MAX, usize::MAX),
)
.expect("Preprocessing failed");
if !seq.multimodal.has_changed_prompt {
let mut prompt = tokenizer
.decode(seq.get_toks(), false)
.expect("Detokenization failed!");
let positions: Vec<usize> = prompt
.match_indices(IMAGE_TOKEN)
.map(|(idx, _)| idx)
.collect();
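                    // As with audio, replace back-to-front so earlier byte
                    // offsets remain valid as each placeholder expands.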
for (i, &pos) in positions.iter().enumerate().rev() {
let (new_h, new_w) = if i < per_image_dims.len() {
per_image_dims[i]
} else {
let grid_unit = self.pooling_kernel_size * self.patch_size;
(grid_unit, grid_unit)
};
let num_tokens = self.output_tokens_for_size(new_h, new_w);
let replacement = self.build_image_sequence(num_tokens);
prompt = format!(
"{}{}{}",
&prompt[..pos],
replacement,
&prompt[pos + IMAGE_TOKEN.len()..],
);
}
seq.set_initial_prompt(prompt.clone());
let toks = tokenizer
.encode_fast(prompt.as_str(), false)
.expect("Tokenization failed!");
let ids = toks.get_ids().to_vec();
seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
has_changed_prompt = true;
}
let n_images = pixel_values.dim(0).unwrap_or(0);
let image_ranges = find_image_placeholder_ranges(seq.get_toks(), IMAGE_TOKEN_ID);
let cached_image_tokens =
cached_tokens_for_ranges(seq.prefix_cache_len(), &image_ranges);
let seq_image_hashes = seq.image_hashes().unwrap_or(&[]);
let image_sizes = image_sizes_all.unwrap_or_default();
for idx in 0..n_images {
let total_tokens = image_ranges
.get(idx)
.map(|(_, length)| *length)
.unwrap_or_else(|| {
image_sizes
.get(idx)
.map(|&(h, w)| self.output_tokens_for_size(h as usize, w as usize))
.unwrap_or(0)
});
let cached_tokens = cached_image_tokens
.get(idx)
.copied()
.unwrap_or(0)
.min(total_tokens);
if cached_tokens >= total_tokens {
continue;
}
pixel_values_accum.push(pixel_values.get(idx)?.unsqueeze(0)?);
if let Some(&size) = image_sizes.get(idx) {
image_sizes_accum.push(size);
}
if let Some(&hash) = seq_image_hashes.get(idx) {
image_hashes_accum.push(hash);
}
image_cached_tokens_accum.push(cached_tokens);
}
}
            if pixel_values_accum.is_empty() {
                (None, vec![])
            } else {
                (
                    Some(Tensor::cat(&pixel_values_accum, 0)?),
                    image_sizes_accum,
                )
            }
} else {
(None, vec![])
};
let video_pixel_values = if has_videos {
for seq in input_seqs.iter_mut() {
if !seq.multimodal.has_changed_prompt {
let toks = seq.get_toks();
let video_ranges = find_image_placeholder_ranges(toks, VIDEO_TOKEN_ID);
let already_expanded =
!video_ranges.is_empty() && video_ranges.iter().all(|(_, len)| *len > 1);
if already_expanded {
continue;
}
}
if let Some(videos) = seq.take_videos() {
for video in &videos {
if video.frames.is_empty() {
continue;
}
let (sample_w, sample_h) = video.frames[0].dimensions();
let (new_h, new_w) =
self.compute_video_resize_dims(sample_h as usize, sample_w as usize)?;
let tokens_per_frame = self.video_tokens_for_size(new_h, new_w);
let timestamps = video.timestamp_strings();
if !seq.multimodal.has_changed_prompt {
let mut prompt = tokenizer
.decode(seq.get_toks(), false)
.expect("Detokenization failed!");
if let Some(pos) = prompt.find(VIDEO_TOKEN) {
                            let replacement =
                                self.build_video_sequence(&timestamps, tokens_per_frame);
prompt = format!(
"{}{}{}",
&prompt[..pos],
replacement,
&prompt[pos + VIDEO_TOKEN.len()..],
);
}
seq.set_initial_prompt(prompt.clone());
let toks = tokenizer
.encode_fast(prompt.as_str(), false)
.expect("Tokenization failed!");
let ids = toks.get_ids().to_vec();
seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
has_changed_prompt = true;
}
let do_rescale = preprocessor_config.do_rescale.unwrap_or(true);
let rescale_factor =
preprocessor_config.rescale_factor.unwrap_or(1.0 / 255.0);
let resample = preprocessor_config.resampling.to_filter()?;
for frame in &video.frames {
let frame_rgb = DynamicImage::ImageRgb8(frame.to_rgb8());
let resized =
frame_rgb.resize_exact(new_w as u32, new_h as u32, resample);
let transforms = Transforms {
input: &ToTensorNoNorm,
inner_transforms: &[&do_rescale.then_some(Rescale {
factor: Some(rescale_factor),
})],
};
let tensor = resized.apply(transforms, device)?;
video_pixel_values_accum.push(tensor.unsqueeze(0)?);
video_sizes_accum.push((new_h as u32, new_w as u32));
}
}
let video_ranges =
find_image_placeholder_ranges(seq.get_toks(), VIDEO_TOKEN_ID);
let cached_video_tokens =
cached_tokens_for_ranges(seq.prefix_cache_len(), &video_ranges);
let all_video_cached = !video_ranges.is_empty()
&& video_ranges.iter().enumerate().all(|(i, &(_, length))| {
cached_video_tokens.get(i).copied().unwrap_or(0) >= length
});
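                    // Every placeholder range for this sequence's videos is
                    // already covered by the cached prefix: roll back the
                    // frames pushed above, since nothing needs re-encoding.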
if all_video_cached {
let n_frames_this_video: usize =
videos.iter().map(|v| v.frames.len()).sum();
let start = video_pixel_values_accum
.len()
.saturating_sub(n_frames_this_video);
video_pixel_values_accum.truncate(start);
video_sizes_accum.truncate(start);
} else {
let mut global_frame_idx = 0;
for video in &videos {
let frame_hashes = video.frame_hashes();
for hash in frame_hashes {
video_hashes_accum.push(hash);
video_cached_tokens_accum.push(
cached_video_tokens
.get(global_frame_idx)
.copied()
.unwrap_or(0),
);
global_frame_idx += 1;
}
}
}
}
}
if video_pixel_values_accum.is_empty() {
None
} else {
let max_h = video_sizes_accum.iter().map(|(h, _)| *h).max().unwrap_or(0) as usize;
let max_w = video_sizes_accum.iter().map(|(_, w)| *w).max().unwrap_or(0) as usize;
let mut padded = Vec::new();
for (pv, &(h, w)) in video_pixel_values_accum
.iter()
.zip(video_sizes_accum.iter())
{
let h = h as usize;
let w = w as usize;
if h < max_h || w < max_w {
let p =
pv.pad_with_zeros(2, 0, max_h - h)?
.pad_with_zeros(3, 0, max_w - w)?;
padded.push(p);
} else {
padded.push(pv.clone());
}
}
Some(Tensor::cat(&padded, 0)?)
}
} else {
None
};
for seq in input_seqs.iter_mut() {
if seq.mm_features().is_empty() {
let mut features = Vec::new();
if let Some(hashes) = seq.image_hashes().map(|h| h.to_vec()) {
if !hashes.is_empty() {
let ranges = find_image_placeholder_ranges(seq.get_toks(), IMAGE_TOKEN_ID);
features.extend(build_mm_features_from_ranges(&ranges, &hashes, "img"));
}
}
if let Some(audio_hashes) = seq.audio_hashes().map(|h| h.to_vec()) {
if !audio_hashes.is_empty() {
let audio_ranges =
find_image_placeholder_ranges(seq.get_toks(), AUDIO_TOKEN_ID);
features.extend(build_mm_features_from_ranges(
&audio_ranges,
&audio_hashes,
"audio",
));
}
}
if let Some(vid_hashes) = seq.video_hashes().map(|h| h.to_vec()) {
if !vid_hashes.is_empty() {
let video_ranges =
find_image_placeholder_ranges(seq.get_toks(), VIDEO_TOKEN_ID);
features.extend(build_mm_features_from_ranges(
&video_ranges,
&vid_hashes,
"video",
));
}
}
if !features.is_empty() {
features.sort_by_key(|f| f.offset);
seq.set_mm_features(features);
}
}
seq.multimodal.has_changed_prompt |= has_changed_prompt;
}
let text_models_inputs_processor::InnerInputProcessorOutput {
inputs:
text_models_inputs_processor::InputMetadata {
input,
positions,
context_lens,
position_ids,
paged_attn_meta,
flash_meta,
},
seq_indices,
} = if is_prompt {
get_prompt_input(
input_seqs
.iter()
.map(|seq| seq.get_toks())
.collect::<Vec<_>>(),
input_seqs,
device,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
mapper,
sliding_window,
)
.unwrap()
} else {
get_completion_input(
input_seqs
.iter()
.map(|seq| seq.get_toks())
.collect::<Vec<_>>(),
input_seqs,
device,
no_kv_cache,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
mapper,
sliding_window,
)
.unwrap()
};
let (pixel_values, image_sizes) = if is_prompt {
pixel_values
} else {
(None, vec![])
};
let video_pixel_values = if is_prompt { video_pixel_values } else { None };
let inputs: Box<dyn Any> = Box::new(ModelInputs {
input_ids: input,
seqlen_offsets: positions,
context_lens,
position_ids,
pixel_values,
model_specific_args: Box::new(Gemma4SpecificArgs {
audio_mel,
audio_mel_mask,
image_hashes: if is_prompt {
image_hashes_accum
} else {
vec![]
},
image_cached_tokens: if is_prompt {
image_cached_tokens_accum
} else {
vec![]
},
image_sizes,
audio_hashes: if is_prompt {
audio_hashes_accum
} else {
vec![]
},
audio_cached_tokens: if is_prompt {
audio_cached_tokens_accum
} else {
vec![]
},
video_pixel_values,
video_hashes: if is_prompt {
video_hashes_accum
} else {
vec![]
},
video_cached_tokens: if is_prompt {
video_cached_tokens_accum
} else {
vec![]
},
video_sizes: if is_prompt { video_sizes_accum } else { vec![] },
}),
paged_attn_meta,
flash_meta,
});
Ok(InputProcessorOutput {
inputs,
seq_indices,
})
}
}
impl ImagePreProcessor for Gemma4ImageProcessor {
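    // Identity mean/std: Gemma4 inputs are rescaled to [0, 1] but not
    // normalized (`ToTensorNoNorm` plus an optional `Rescale` below).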
const DEFAULT_MEAN: [f64; 3] = [0.0, 0.0, 0.0];
const DEFAULT_STD: [f64; 3] = [1.0, 1.0, 1.0];
fn preprocess(
&self,
mut images: Vec<DynamicImage>,
        _videos: Vec<Vec<DynamicImage>>,
config: &PreProcessorConfig,
device: &Device,
(_bs, _max_num_images): (usize, usize),
) -> Result<PreprocessedImages> {
let do_rescale = config.do_rescale.unwrap_or(true);
let rescale_factor = config.rescale_factor.unwrap_or(1.0 / 255.0);
let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
let resample = config.resampling.to_filter()?;
for image in images.iter_mut() {
if do_convert_rgb {
*image = DynamicImage::ImageRgb8(image.to_rgb8());
}
}
let mut pixel_values = Vec::new();
let mut image_sizes = Vec::new();
for image in images {
let (w, h) = image.dimensions();
let (new_h, new_w) = self.compute_resize_dims(h as usize, w as usize)?;
let resized = image.resize_exact(new_w as u32, new_h as u32, resample);
let transforms = Transforms {
input: &ToTensorNoNorm,
inner_transforms: &[&do_rescale.then_some(Rescale {
factor: Some(rescale_factor),
})],
};
let tensor = resized.apply(transforms, device)?;
pixel_values.push(tensor.unsqueeze(0)?);
image_sizes.push((new_h as u32, new_w as u32));
}
let max_h = image_sizes.iter().map(|(h, _)| *h).max().unwrap_or(0) as usize;
let max_w = image_sizes.iter().map(|(_, w)| *w).max().unwrap_or(0) as usize;
let mut padded = Vec::new();
for (pv, &(h, w)) in pixel_values.iter().zip(image_sizes.iter()) {
let h = h as usize;
let w = w as usize;
if h < max_h || w < max_w {
let pad_h = max_h - h;
let pad_w = max_w - w;
let p = pv
.pad_with_zeros(2, 0, pad_h)?
.pad_with_zeros(3, 0, pad_w)?;
padded.push(p);
} else {
padded.push(pv.clone());
}
}
Ok(PreprocessedImages {
pixel_values: Tensor::cat(&padded, 0)?,
pixel_attention_mask: None,
image_sizes: None,
num_img_tokens: None,
aspect_ratio_ids: None,
aspect_ratio_mask: None,
num_tiles: None,
image_grid_thw: None,
video_grid_thw: None,
rows: None,
cols: None,
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: Some(image_sizes),
num_crops: None,
})
}
}
#[cfg(test)]
mod tests {
use super::{cached_tokens_for_ranges, Gemma4Processor};
use crate::vision_models::processor_config::ProcessorConfig;
#[test]
fn defaults_audio_seq_length_to_reference_cap() {
let processor = Gemma4Processor::new(ProcessorConfig::default(), 16, 3, 280, true, true);
assert_eq!(processor.audio_seq_length, 750);
}
#[test]
fn cached_tokens_for_ranges_handles_partial_overlap() {
let ranges = vec![(5, 4), (12, 3), (20, 2)];
assert_eq!(cached_tokens_for_ranges(0, &ranges), vec![0, 0, 0]);
assert_eq!(cached_tokens_for_ranges(7, &ranges), vec![2, 0, 0]);
assert_eq!(cached_tokens_for_ranges(13, &ranges), vec![4, 1, 0]);
assert_eq!(cached_tokens_for_ranges(30, &ranges), vec![4, 3, 2]);
}
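    // Sketch tests under assumed geometry (16-px patches, 3x3 pooling, 280
    // output tokens, matching the constructor test above). Only the fields
    // each test exercises are meaningful; the rest are placeholders.
    fn test_processor() -> super::Gemma4ImageProcessor {
        super::Gemma4ImageProcessor {
            patch_size: 16,
            pooling_kernel_size: 3,
            default_output_length: 280,
            max_patches: 2520,
            audio_seq_length: 750,
            video_max_soft_tokens: 70,
            video_max_patches: 630,
            supports_images: true,
            supports_audio: true,
        }
    }
    #[test]
    fn audio_token_count_applies_two_stride_two_stages() {
        let p = test_processor();
        assert_eq!(p.compute_audio_num_tokens(0), 0);
        // 100 mel frames -> (100 - 1) / 2 + 1 = 50 -> (50 - 1) / 2 + 1 = 25.
        assert_eq!(p.compute_audio_num_tokens(100), 25);
        // Long clips are capped at `audio_seq_length`.
        assert_eq!(p.compute_audio_num_tokens(1_000_000), 750);
    }
    #[test]
    fn resize_snaps_to_pooling_grid_within_patch_budget() {
        let p = test_processor();
        // sqrt(2520 * 16 * 16 / 1_000_000) scales 1000x1000 to ~803x803,
        // which floors to the 48-px grid at 768x768: 48x48 patches, pooled
        // 3x3 into 256 output tokens.
        assert_eq!(p.compute_resize_dims(1000, 1000).unwrap(), (768, 768));
        assert_eq!(p.output_tokens_for_size(768, 768), 256);
    }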
}