use image::{DynamicImage, GenericImageView};
use ndarray::{Array2, Array3};
use crate::vision::{
image_processor::{ImagePreProcessor, ModelSpecificValue, PreprocessedImages},
preprocessor_config::PreProcessorConfig,
transforms::{normalize, pil_to_filter, resize, to_tensor, TransformError},
};
/// Rounds `x` to the nearest integer, breaking exact `.5` ties toward the even neighbour
/// ("banker's rounding", the same rule as Python's built-in `round`).
///
/// `f64::round` breaks ties away from zero; when that lands on an odd value for an exact tie,
/// the even neighbour is `2 * x - rounded` (one unit back toward zero). This holds for both
/// positive and negative inputs. NaN and infinities are returned unchanged.
#[inline]
fn round_half_to_even(x: f64) -> f64 {
    let rounded = x.round();
    // Exact comparison is intentional: only genuine ties get the even-rounding treatment.
    let is_tie = x.fract().abs() == 0.5;
    if is_tie && rounded % 2.0 != 0.0 {
        2.0 * x - rounded
    } else {
        rounded
    }
}
/// Model-specific settings that drive Qwen-VL style resizing and patchification.
#[derive(Debug, Clone)]
pub struct QwenVLConfig {
    /// Side length, in pixels, of one square vision patch.
    pub patch_size: usize,
    /// Number of neighbouring patches merged along each spatial axis; the token count is the
    /// patch count divided by `merge_size * merge_size`.
    pub merge_size: usize,
    /// Lower bound on the resized pixel area (`height * width`) enforced by `smart_resize`.
    pub min_pixels: usize,
    /// Upper bound on the resized pixel area (`height * width`) enforced by `smart_resize`.
    pub max_pixels: usize,
    /// Number of frames grouped into one temporal patch; a single image is replicated this
    /// many times before patchification.
    pub temporal_patch_size: usize,
    /// Per-channel normalization mean reported by `default_mean`.
    pub mean: [f64; 3],
    /// Per-channel normalization standard deviation reported by `default_std`.
    pub std: [f64; 3],
    /// Identifier reported by `ImagePreProcessor::model_name`.
    pub model_name: &'static str,
}
/// Shared Qwen-VL image preprocessor; all sizing and patching decisions come from `config`.
#[derive(Debug, Clone)]
pub struct QwenVLProcessorBase {
    /// Model-specific settings (patch/merge sizes, pixel bounds, normalization stats).
    config: QwenVLConfig,
}
impl QwenVLProcessorBase {
    /// Creates a processor that uses `config` for every sizing and patching decision.
    pub fn new(config: QwenVLConfig) -> Self {
        Self { config }
    }

    /// Side length, in pixels, of one square vision patch.
    pub fn patch_size(&self) -> usize {
        self.config.patch_size
    }

    /// Number of neighbouring patches merged along each spatial axis into one token.
    pub fn merge_size(&self) -> usize {
        self.config.merge_size
    }

    /// Lower bound on the resized pixel area used by [`Self::smart_resize`].
    pub fn min_pixels(&self) -> usize {
        self.config.min_pixels
    }

    /// Upper bound on the resized pixel area used by [`Self::smart_resize`].
    pub fn max_pixels(&self) -> usize {
        self.config.max_pixels
    }

    /// Number of frames grouped into one temporal patch.
    pub fn temporal_patch_size(&self) -> usize {
        self.config.temporal_patch_size
    }

    /// Granularity (`patch_size * merge_size`) that resized dimensions are snapped to.
    #[inline]
    pub fn get_factor(&self) -> usize {
        self.config.patch_size * self.config.merge_size
    }

    /// Computes the target `(height, width)` for an image of the given size.
    ///
    /// Both outputs are multiples of [`Self::get_factor`]. Each dimension is first rounded
    /// (ties to even) to the nearest multiple of the factor. If the resulting area exceeds
    /// `max_pixels`, both dimensions are scaled down by a common factor and floored to the grid
    /// (never below one factor). If the area is below `min_pixels`, both dimensions are scaled up
    /// by a common factor and ceiled to the grid.
    ///
    /// # Errors
    /// Returns [`TransformError::InvalidShape`] if either dimension is smaller than the factor
    /// or the aspect ratio exceeds 200:1.
    pub fn smart_resize(
        &self,
        height: usize,
        width: usize,
    ) -> Result<(usize, usize), TransformError> {
        let factor = self.get_factor();
        if height < factor || width < factor {
            return Err(TransformError::InvalidShape {
                expected: format!("dimensions >= {factor} (patch_size * merge_size)"),
                actual: vec![height, width],
            });
        }

        let (h, w) = (height as f64, width as f64);
        let aspect_ratio = h.max(w) / h.min(w);
        if aspect_ratio > 200.0 {
            return Err(TransformError::InvalidShape {
                expected: "aspect ratio < 200:1".to_string(),
                actual: vec![height, width],
            });
        }

        let f = factor as f64;
        let mut h_bar = (round_half_to_even(h / f) as usize * factor).max(factor);
        let mut w_bar = (round_half_to_even(w / f) as usize * factor).max(factor);

        // Areas are computed in f64 (and compared with `saturating_mul`) so very large
        // images cannot overflow `usize` on 32-bit targets; for realistic sizes the values
        // are identical to exact integer arithmetic.
        if h_bar.saturating_mul(w_bar) > self.config.max_pixels {
            let beta = (h * w / self.config.max_pixels as f64).sqrt();
            h_bar = ((h / beta / f).floor() as usize * factor).max(factor);
            w_bar = ((w / beta / f).floor() as usize * factor).max(factor);
        } else if h_bar.saturating_mul(w_bar) < self.config.min_pixels {
            let beta = (self.config.min_pixels as f64 / (h * w)).sqrt();
            h_bar = (h * beta / f).ceil() as usize * factor;
            w_bar = (w * beta / f).ceil() as usize * factor;
        }
        Ok((h_bar, w_bar))
    }

    /// Returns `(grid_t, grid_h, grid_w)`: the number of patches along time, height and width.
    ///
    /// `height` and `width` are divided by `patch_size` with flooring, so callers should pass
    /// patch-aligned sizes such as the output of [`Self::smart_resize`]. Frames are grouped into
    /// temporal patches of `temporal_patch_size`. A trailing partial group still counts as a
    /// group, and at least one group is always reported, so a single image yields `grid_t == 1`.
    pub fn calculate_grid_thw(
        &self,
        height: usize,
        width: usize,
        num_frames: usize,
    ) -> (usize, usize, usize) {
        let temporal_patch_size = self.config.temporal_patch_size;
        // Qwen2-VL preprocessing pads an incomplete trailing frame group (by repeating the
        // last frame), so that group still yields a temporal patch: round up, not down.
        let grid_t = num_frames
            .max(temporal_patch_size)
            .div_ceil(temporal_patch_size);
        let grid_h = height / self.config.patch_size;
        let grid_w = width / self.config.patch_size;
        (grid_t, grid_h, grid_w)
    }

    /// Number of LLM tokens for a patch grid: every `merge_size x merge_size` block of
    /// patches collapses into one token.
    pub fn calculate_tokens_from_grid(&self, grid_t: usize, grid_h: usize, grid_w: usize) -> usize {
        (grid_t * grid_h * grid_w) / (self.config.merge_size * self.config.merge_size)
    }

    /// Flattens a single `[channels, height, width]` image tensor into Qwen-VL patch rows.
    ///
    /// The frame is replicated `temporal_patch_size` times along a new leading time axis. It is
    /// then cut into `patch_size x patch_size` patches, ordered so that the patches of each
    /// `merge_size x merge_size` block are contiguous rows. Each row holds the features in
    /// `[channel, temporal, patch_row, patch_col]` order.
    ///
    /// The result is a row-major buffer of shape
    /// `[grid_t * grid_h * grid_w, channels * temporal_patch_size * patch_size^2]`.
    ///
    /// # Errors
    /// Returns [`TransformError::ShapeError`] in any of these cases:
    /// - the tensor's spatial size is not `grid_h * patch_size` by `grid_w * patch_size`;
    /// - the grid is not divisible by `merge_size`;
    /// - a reshape fails, e.g. `grid_t != 1`, because only one frame is supplied.
    pub fn reshape_to_patches(
        &self,
        tensor: &Array3<f32>,
        grid_t: usize,
        grid_h: usize,
        grid_w: usize,
    ) -> Result<Vec<f32>, TransformError> {
        use ndarray::{Axis, IxDyn};

        let (channel, height, width) = tensor.dim();
        let patch_size = self.config.patch_size;
        let merge_size = self.config.merge_size;
        let temporal_patch_size = self.config.temporal_patch_size;

        // Validate in all build profiles: a mismatch must surface as a descriptive error,
        // not a debug-only panic or an opaque reshape failure.
        if height != grid_h * patch_size || width != grid_w * patch_size {
            return Err(TransformError::ShapeError(format!(
                "Tensor spatial size [{height}, {width}] does not match grid [{grid_h}, {grid_w}] * patch_size {patch_size}"
            )));
        }
        if merge_size == 0 || grid_h % merge_size != 0 || grid_w % merge_size != 0 {
            return Err(TransformError::ShapeError(format!(
                "Grid [{grid_h}, {grid_w}] is not divisible by merge_size {merge_size}"
            )));
        }

        // Replicate the single frame along a new leading time axis (materialized in
        // standard layout so it can be reshaped in row-major order).
        let expanded = tensor
            .view()
            .insert_axis(Axis(0))
            .broadcast((temporal_patch_size, channel, height, width))
            .ok_or_else(|| TransformError::ShapeError(
                format!("Broadcast failed: cannot broadcast [1, {channel}, {height}, {width}] to [{temporal_patch_size}, {channel}, {height}, {width}]")
            ))?
            .to_owned();

        let grid_h_merged = grid_h / merge_size;
        let grid_w_merged = grid_w / merge_size;
        // Axes: (t, temporal, c, merged_h, merge_h, patch_h, merged_w, merge_w, patch_w).
        let shape_9d = IxDyn(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h_merged,
            merge_size,
            patch_size,
            grid_w_merged,
            merge_size,
            patch_size,
        ]);
        let reshaped = expanded
            .into_shape_with_order(shape_9d)
            .map_err(|e| TransformError::ShapeError(format!("Reshape to 9D failed: {e}")))?;

        // Reorder to (t, merged_h, merged_w, merge_h, merge_w | c, temporal, patch_h, patch_w):
        // the first five axes enumerate patches, the last four are per-patch features.
        let permuted = reshaped.permuted_axes(&[0, 3, 6, 4, 7, 2, 1, 5, 8][..]);
        let num_patches = grid_t * grid_h * grid_w;
        let patch_features = channel * temporal_patch_size * patch_size * patch_size;

        // The permuted view is strided; copy into row-major order before flattening.
        let contiguous = permuted.as_standard_layout().into_owned();
        let flat = contiguous
            .into_shape_with_order(IxDyn(&[num_patches, patch_features]))
            .map_err(|e| {
                TransformError::ShapeError(format!(
                    "Final reshape to [{num_patches}, {patch_features}] failed: {e}"
                ))
            })?;
        let (vec, _offset) = flat.into_raw_vec_and_offset();
        Ok(vec)
    }
}
impl ImagePreProcessor for QwenVLProcessorBase {
    /// Per-channel normalization mean taken from the model config.
    fn default_mean(&self) -> [f64; 3] {
        self.config.mean
    }

    /// Per-channel normalization standard deviation taken from the model config.
    fn default_std(&self) -> [f64; 3] {
        self.config.std
    }

    /// Resizes (optionally), normalizes (optionally) and patchifies a batch of images.
    ///
    /// `pixel_values` has shape `[total_patches, 3 * temporal_patch_size * patch_size^2]`, with
    /// every image's patch rows concatenated in input order. Two extras are attached:
    /// - `image_grid_thw`: `[num_images, 3]`, one `(t, h, w)` patch grid per image;
    /// - `patches_per_image`: the number of patch rows each image contributed.
    ///
    /// The patch grid is derived from the pixels actually patchified. With `do_resize` enabled
    /// (the default) that is the [`QwenVLProcessorBase::smart_resize`] target. Otherwise it is
    /// the original image size.
    ///
    /// # Errors
    /// - [`TransformError::EmptyBatch`] for an empty slice.
    /// - [`TransformError::InvalidShape`] if an image cannot be smart-resized, or if (with
    ///   resizing disabled) it is not aligned to `patch_size * merge_size`.
    /// - [`TransformError::ShapeError`] if patchification fails.
    fn preprocess(
        &self,
        images: &[DynamicImage],
        config: &PreProcessorConfig,
    ) -> Result<PreprocessedImages, TransformError> {
        if images.is_empty() {
            return Err(TransformError::EmptyBatch);
        }
        let image_sizes: Vec<(u32, u32)> = images.iter().map(|img| img.dimensions()).collect();
        let mean = config.get_image_mean();
        let std = config.get_image_std();
        let filter = pil_to_filter(config.resampling);
        let do_resize = config.do_resize.unwrap_or(true);
        let do_normalize = config.do_normalize.unwrap_or(true);
        let factor = self.get_factor();
        let patch_size = self.config.patch_size;
        let temporal_patch_size = self.config.temporal_patch_size;
        // NOTE(review): assumes `to_tensor` always yields 3 channels (RGB) — confirm.
        let patch_features = 3 * temporal_patch_size * patch_size * patch_size;
        let mut all_patches: Vec<f32> = Vec::new();
        let mut patches_per_image: Vec<i64> = Vec::with_capacity(images.len());
        let mut grid_thw_data: Vec<i64> = Vec::with_capacity(images.len() * 3);
        let mut num_img_tokens = Vec::with_capacity(images.len());
        for image in images {
            // Only compute a target size (and possibly fail on it) when resizing was requested;
            // otherwise borrow the original instead of cloning it.
            let resized_storage;
            let processed: &DynamicImage = if do_resize {
                let (w, h) = image.dimensions();
                let (target_h, target_w) = self.smart_resize(h as usize, w as usize)?;
                resized_storage = resize(image, target_w as u32, target_h as u32, filter);
                &resized_storage
            } else {
                image
            };

            // The grid must describe the tensor we actually patchify, so take it from the
            // processed image rather than from a resize target that may not have been applied.
            let (proc_w, proc_h) = processed.dimensions();
            let (proc_h, proc_w) = (proc_h as usize, proc_w as usize);
            if proc_h % factor != 0 || proc_w % factor != 0 {
                return Err(TransformError::InvalidShape {
                    expected: format!(
                        "dimensions divisible by {factor} (patch_size * merge_size)"
                    ),
                    actual: vec![proc_h, proc_w],
                });
            }

            let mut tensor = to_tensor(processed);
            if do_normalize {
                normalize(&mut tensor, &mean, &std);
            }

            let (grid_t, grid_h, grid_w) = self.calculate_grid_thw(proc_h, proc_w, 1);
            grid_thw_data.push(grid_t as i64);
            grid_thw_data.push(grid_h as i64);
            grid_thw_data.push(grid_w as i64);
            let num_patches = grid_t * grid_h * grid_w;
            num_img_tokens.push(self.calculate_tokens_from_grid(grid_t, grid_h, grid_w));

            let patches = self.reshape_to_patches(&tensor, grid_t, grid_h, grid_w)?;
            all_patches.extend(patches);
            patches_per_image.push(num_patches as i64);
        }
        let total_patches: usize = patches_per_image.iter().map(|&n| n as usize).sum();
        let pixel_values =
            Array2::from_shape_vec((total_patches, patch_features), all_patches).map_err(|e| {
                TransformError::ShapeError(format!(
                    "Failed to create patchified pixel_values [{total_patches}, {patch_features}]: {e}"
                ))
            })?;
        let result =
            PreprocessedImages::new_dynamic(pixel_values.into_dyn(), num_img_tokens, image_sizes)
                .with_extra(
                    "image_grid_thw",
                    ModelSpecificValue::int_2d(grid_thw_data, images.len(), 3),
                )
                .with_extra(
                    "patches_per_image",
                    ModelSpecificValue::int_1d(patches_per_image),
                );
        Ok(result)
    }

    /// Predicts how many vision tokens `preprocess` emits for a `width x height` image.
    ///
    /// With resizing enabled (the default), the size is first mapped through
    /// [`QwenVLProcessorBase::smart_resize`]. If that rejects the size, a single
    /// `factor x factor` grid is assumed as a best-effort fallback. With `do_resize` disabled,
    /// the original dimensions are used directly, mirroring `preprocess`.
    fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize {
        let (height, width) = (height as usize, width as usize);
        let (new_height, new_width) = if config.do_resize.unwrap_or(true) {
            self.smart_resize(height, width).unwrap_or_else(|_| {
                let factor = self.get_factor();
                (factor, factor)
            })
        } else {
            (height, width)
        };
        let (grid_t, grid_h, grid_w) = self.calculate_grid_thw(new_height, new_width, 1);
        self.calculate_tokens_from_grid(grid_t, grid_h, grid_w)
    }

    /// Model identifier taken from the config.
    fn model_name(&self) -> &'static str {
        self.config.model_name
    }

    /// Always `None`: the output size depends on each input image's dimensions.
    fn get_processed_size(&self, _config: &PreProcessorConfig) -> Option<(u32, u32)> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Qwen2-VL-like settings: 14px patches merged 2x2 (factor 28), area bounds of
    /// 256 and 1280 merged tokens' worth of pixels.
    fn create_test_config() -> QwenVLConfig {
        QwenVLConfig {
            patch_size: 14,
            merge_size: 2,
            min_pixels: 256 * 28 * 28,
            max_pixels: 1280 * 28 * 28,
            temporal_patch_size: 2,
            mean: [0.5; 3],
            std: [0.5; 3],
            model_name: "test-qwen-vl",
        }
    }

    /// Builds a processor from the shared test configuration.
    fn processor() -> QwenVLProcessorBase {
        QwenVLProcessorBase::new(create_test_config())
    }

    #[test]
    fn test_qwen_vl_base_factor() {
        assert_eq!(processor().get_factor(), 14 * 2);
    }

    #[test]
    fn test_smart_resize_within_bounds() {
        let p = processor();
        let (height, width) = p.smart_resize(500, 500).unwrap();
        for side in [height, width] {
            assert_eq!(side % 28, 0);
        }
        let area = height * width;
        assert!((p.min_pixels()..=p.max_pixels()).contains(&area));
    }

    #[test]
    fn test_smart_resize_extreme_aspect_ratio_error() {
        assert!(processor().smart_resize(100, 30000).is_err());
    }

    #[test]
    fn test_calculate_grid_thw() {
        let grid = processor().calculate_grid_thw(448, 448, 1);
        assert_eq!(grid, (1, 448 / 14, 448 / 14));
    }

    #[test]
    fn test_calculate_tokens() {
        let expected = 32 * 32 / 4;
        assert_eq!(processor().calculate_tokens_from_grid(1, 32, 32), expected);
    }
}