use image::{DynamicImage, GenericImageView};
use ndarray::{Array2, Array3};
use crate::vision::{
image_processor::{ImagePreProcessor, ModelSpecificValue, PreprocessedImages},
preprocessor_config::PreProcessorConfig,
transforms::{pil_to_filter, resize, to_tensor, to_tensor_and_normalize, TransformError},
};
/// Rounds `x` to the nearest integer, resolving ties to the even neighbour
/// (banker's rounding, the same rule Python's built-in `round` applies).
///
/// A value counts as a tie when its fractional part is within `1e-9` of `0.5`.
/// The tie is resolved from `floor(x)` rather than from `f64::round`, so the
/// result is correct for negative inputs (`-2.5 -> -2.0`, `-0.5 -> 0.0`) and
/// for values that sit a hair below the tie. Non-finite inputs are returned
/// unchanged (`NaN` stays `NaN`, infinities stay infinite).
#[inline]
fn round_half_to_even(x: f64) -> f64 {
    /// Maximum distance from an exact `.5` fraction that is still treated as a tie.
    const TIE_EPSILON: f64 = 1e-9;

    let lower = x.floor();
    if (x - lower - 0.5).abs() < TIE_EPSILON {
        // The two candidates are `lower` and `lower + 1`; exactly one is even.
        if lower % 2.0 == 0.0 {
            lower
        } else {
            lower + 1.0
        }
    } else {
        x.round()
    }
}
/// Settings that parameterize a [`QwenVLProcessorBase`].
///
/// The processor multiplies `patch_size` by `merge_size` to obtain the
/// "factor" that every resized image dimension is snapped to.
#[derive(Debug, Clone)]
pub struct QwenVLConfig {
/// Side length, in pixels, of one square patch; the patch grid is
/// `height / patch_size` by `width / patch_size`.
pub patch_size: usize,
/// Patches are grouped into `merge_size x merge_size` blocks; the token count
/// is the patch count divided by `merge_size * merge_size`.
pub merge_size: usize,
/// Lower bound on `height * width` enforced by `smart_resize` (smaller
/// results are scaled up).
pub min_pixels: usize,
/// Upper bound on `height * width` enforced by `smart_resize` (larger
/// results are scaled down).
pub max_pixels: usize,
/// Number of temporal slots per patch: the divisor used for the temporal grid
/// size and the number of times each patch row is written when patchifying.
pub temporal_patch_size: usize,
/// Per-channel mean returned by `default_mean`.
pub mean: [f64; 3],
/// Per-channel standard deviation returned by `default_std`.
pub std: [f64; 3],
/// Identifier returned by `model_name`.
pub model_name: &'static str,
}
/// Image preprocessor driven entirely by a [`QwenVLConfig`].
///
/// It computes resize targets, patch-grid dimensions and token counts, and
/// flattens image tensors into patch rows.
#[derive(Debug, Clone)]
pub struct QwenVLProcessorBase {
/// Settings supplied at construction; only read afterwards.
config: QwenVLConfig,
}
impl QwenVLProcessorBase {
    /// Creates a processor that uses `config` for all size and patch computations.
    pub fn new(config: QwenVLConfig) -> Self {
        Self { config }
    }

    /// Side length, in pixels, of one square patch.
    pub fn patch_size(&self) -> usize {
        self.config.patch_size
    }

    /// Number of patches merged along each spatial axis.
    pub fn merge_size(&self) -> usize {
        self.config.merge_size
    }

    /// Minimum `height * width` that [`Self::smart_resize`] aims for.
    pub fn min_pixels(&self) -> usize {
        self.config.min_pixels
    }

    /// Maximum `height * width` that [`Self::smart_resize`] aims for.
    pub fn max_pixels(&self) -> usize {
        self.config.max_pixels
    }

    /// Number of temporal slots per patch.
    pub fn temporal_patch_size(&self) -> usize {
        self.config.temporal_patch_size
    }

    /// Returns `patch_size * merge_size`, the step every resized dimension is a multiple of.
    #[inline]
    pub fn get_factor(&self) -> usize {
        self.config.patch_size * self.config.merge_size
    }

    /// Computes the `(height, width)` an image should be resized to.
    ///
    /// Both dimensions are rounded (ties to even) to the nearest multiple of
    /// [`Self::get_factor`] and clamped to at least one factor. If the resulting
    /// area exceeds `max_pixels`, both sides are scaled down by the same ratio
    /// and floored to a multiple of the factor; if it is below `min_pixels`,
    /// both sides are scaled up and ceiled to a multiple of the factor.
    ///
    /// # Errors
    ///
    /// Returns [`TransformError::InvalidShape`] when either dimension is zero,
    /// when `patch_size * merge_size` is zero, or when the longer side is more
    /// than 200 times the shorter side.
    pub fn smart_resize(
        &self,
        height: usize,
        width: usize,
    ) -> Result<(usize, usize), TransformError> {
        if height == 0 || width == 0 {
            return Err(TransformError::InvalidShape {
                expected: "non-zero dimensions".to_string(),
                actual: vec![height, width],
            });
        }

        let factor = self.get_factor();
        if factor == 0 {
            // Dividing by a zero factor would push inf/NaN through the math below
            // and silently yield a (0, 0) target.
            return Err(TransformError::InvalidShape {
                expected: "non-zero patch_size and merge_size".to_string(),
                actual: vec![self.config.patch_size, self.config.merge_size],
            });
        }

        let longer = height.max(width) as f64;
        let shorter = height.min(width) as f64;
        if longer / shorter > 200.0 {
            return Err(TransformError::InvalidShape {
                expected: "aspect ratio < 200:1".to_string(),
                actual: vec![height, width],
            });
        }

        let step = factor as f64;
        let (h, w) = (height as f64, width as f64);
        let area = (height * width) as f64;

        let mut h_bar = (round_half_to_even(h / step) as usize * factor).max(factor);
        let mut w_bar = (round_half_to_even(w / step) as usize * factor).max(factor);

        if h_bar * w_bar > self.config.max_pixels {
            // Too large: shrink both sides by the same ratio, rounding down.
            let beta = (area / self.config.max_pixels as f64).sqrt();
            h_bar = ((h / beta / step).floor() as usize * factor).max(factor);
            w_bar = ((w / beta / step).floor() as usize * factor).max(factor);
        } else if h_bar * w_bar < self.config.min_pixels {
            // Too small: grow both sides by the same ratio, rounding up.
            let beta = (self.config.min_pixels as f64 / area).sqrt();
            h_bar = (h * beta / step).ceil() as usize * factor;
            w_bar = (w * beta / step).ceil() as usize * factor;
        }

        Ok((h_bar, w_bar))
    }

    /// Returns the patch grid `(grid_t, grid_h, grid_w)` for the given size.
    ///
    /// `grid_t` is `max(num_frames, temporal_patch_size) / temporal_patch_size`
    /// using integer (floor) division, so a single frame yields `1`. The spatial
    /// entries are `height / patch_size` and `width / patch_size`, also floored.
    pub fn calculate_grid_thw(
        &self,
        height: usize,
        width: usize,
        num_frames: usize,
    ) -> (usize, usize, usize) {
        let temporal = self.config.temporal_patch_size;
        let patch = self.config.patch_size;
        (
            num_frames.max(temporal) / temporal,
            height / patch,
            width / patch,
        )
    }

    /// Number of tokens for a grid: total patches divided by `merge_size * merge_size`.
    pub fn calculate_tokens_from_grid(&self, grid_t: usize, grid_h: usize, grid_w: usize) -> usize {
        (grid_t * grid_h * grid_w) / (self.config.merge_size * self.config.merge_size)
    }

    /// Appends the patch rows of `tensor` to `output`.
    ///
    /// `tensor` is indexed as `[channel, height, width]`. Patches are emitted
    /// block by block, where a block is `merge_size x merge_size` neighbouring
    /// patches; within one patch the values are ordered channel, temporal slot,
    /// row, column. The same image rows are written for every temporal slot and
    /// for every `grid_t` step. On success exactly
    /// `grid_t * grid_h * grid_w * channel * temporal_patch_size * patch_size^2`
    /// values are appended; on error `output` is left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`TransformError::InvalidShape`] when `merge_size` is zero, when
    /// `grid_h` or `grid_w` is not a multiple of `merge_size`, or when the tensor
    /// height/width differ from `grid_h * patch_size` / `grid_w * patch_size`.
    /// Returns [`TransformError::ShapeError`] if no contiguous view of the tensor
    /// can be obtained.
    pub fn patchify_into(
        &self,
        tensor: &Array3<f32>,
        grid_t: usize,
        grid_h: usize,
        grid_w: usize,
        output: &mut Vec<f32>,
    ) -> Result<(), TransformError> {
        let shape = tensor.shape();
        let (channel, height, width) = (shape[0], shape[1], shape[2]);
        let patch_size = self.config.patch_size;
        let merge_size = self.config.merge_size;
        let temporal_patch_size = self.config.temporal_patch_size;

        // Validate before touching `output`, in release builds as well: a
        // mismatch would otherwise read with the wrong stride or run past the
        // end of a plane.
        if merge_size == 0 || grid_h % merge_size != 0 || grid_w % merge_size != 0 {
            return Err(TransformError::InvalidShape {
                expected: format!(
                    "grid_h and grid_w divisible by non-zero merge_size ({merge_size})"
                ),
                actual: vec![grid_h, grid_w],
            });
        }
        if height != grid_h * patch_size || width != grid_w * patch_size {
            return Err(TransformError::InvalidShape {
                expected: format!(
                    "tensor of shape [{channel}, {}, {}] (grid * patch_size)",
                    grid_h * patch_size,
                    grid_w * patch_size
                ),
                actual: shape.to_vec(),
            });
        }

        let data = tensor.as_standard_layout();
        let flat = data.as_slice().ok_or_else(|| {
            TransformError::ShapeError("tensor not contiguous after as_standard_layout".to_string())
        })?;

        // One contiguous `height * width` slice per channel.
        let plane_len = height * width;
        let planes: Vec<&[f32]> = (0..channel)
            .map(|c| &flat[c * plane_len..(c + 1) * plane_len])
            .collect();

        let patch_features = channel * temporal_patch_size * patch_size * patch_size;
        output.reserve(grid_t * grid_h * grid_w * patch_features);

        for _ in 0..grid_t {
            for block_row in 0..grid_h / merge_size {
                for block_col in 0..grid_w / merge_size {
                    for mh in 0..merge_size {
                        // Top pixel row of the current patch.
                        let y0 = (block_row * merge_size + mh) * patch_size;
                        for mw in 0..merge_size {
                            // Left pixel column of the current patch.
                            let x0 = (block_col * merge_size + mw) * patch_size;
                            for plane in &planes {
                                for _ in 0..temporal_patch_size {
                                    for py in 0..patch_size {
                                        let start = (y0 + py) * width + x0;
                                        output.extend_from_slice(&plane[start..start + patch_size]);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }
}
impl ImagePreProcessor for QwenVLProcessorBase {
    /// Per-channel mean from the processor configuration.
    fn default_mean(&self) -> [f64; 3] {
        self.config.mean
    }

    /// Per-channel standard deviation from the processor configuration.
    fn default_std(&self) -> [f64; 3] {
        self.config.std
    }

    /// Resizes, optionally normalizes and patchifies every image in `images`.
    ///
    /// The result holds a `[total_patches, patch_features]` array of patch rows
    /// (with `patch_features = 3 * temporal_patch_size * patch_size^2`), the
    /// token count per image, the original `dimensions()` of each image, and
    /// two extras: `"image_grid_thw"` (`images.len() x 3`) and
    /// `"patches_per_image"`.
    ///
    /// When `config.do_resize` is unset or `true`, each image is resized to its
    /// `smart_resize` target. When it is `false`, the image is used at its own
    /// size, which must then already be a non-zero multiple of
    /// `patch_size * merge_size` in both dimensions.
    ///
    /// # Errors
    ///
    /// * [`TransformError::EmptyBatch`] if `images` is empty.
    /// * [`TransformError::InvalidShape`] if `smart_resize` rejects an image, or
    ///   if resizing is disabled and an image size is not patch-aligned.
    /// * [`TransformError::ShapeError`] if the patch buffer cannot be reshaped.
    fn preprocess(
        &self,
        images: &[DynamicImage],
        config: &PreProcessorConfig,
    ) -> Result<PreprocessedImages, TransformError> {
        if images.is_empty() {
            return Err(TransformError::EmptyBatch);
        }

        let do_resize = config.do_resize.unwrap_or(true);
        let do_normalize = config.do_normalize.unwrap_or(true);
        let mean = config.get_image_mean();
        let std = config.get_image_std();
        let filter = pil_to_filter(config.resampling);

        let factor = self.get_factor();
        let patch_size = self.config.patch_size;
        let temporal_patch_size = self.config.temporal_patch_size;
        let patch_features = 3 * temporal_patch_size * patch_size * patch_size;

        let image_sizes: Vec<(u32, u32)> = images.iter().map(|img| img.dimensions()).collect();

        // Pass 1: settle every image's target size and grid. This rejects bad
        // inputs before any pixel work and gives the exact output size.
        let mut plans = Vec::with_capacity(images.len());
        let mut total_patches = 0usize;
        for &(w, h) in &image_sizes {
            let (h, w) = (h as usize, w as usize);
            let (target_h, target_w) = if do_resize {
                self.smart_resize(h, w)?
            } else if factor == 0 || h == 0 || w == 0 || h % factor != 0 || w % factor != 0 {
                // Without resizing the grid must describe the image as it is;
                // a misaligned size cannot be split into whole merged patches.
                return Err(TransformError::InvalidShape {
                    expected: format!(
                        "non-zero height and width divisible by {factor} when do_resize is false"
                    ),
                    actual: vec![h, w],
                });
            } else {
                (h, w)
            };
            let grid = self.calculate_grid_thw(target_h, target_w, 1);
            total_patches += grid.0 * grid.1 * grid.2;
            plans.push((target_h, target_w, grid));
        }

        let mut all_patches: Vec<f32> = Vec::with_capacity(total_patches * patch_features);
        let mut patches_per_image: Vec<i64> = Vec::with_capacity(images.len());
        let mut grid_thw_data: Vec<i64> = Vec::with_capacity(images.len() * 3);
        let mut num_img_tokens = Vec::with_capacity(images.len());

        // Pass 2: resize (only if the size actually changes), convert, patchify.
        for (image, &(target_h, target_w, (grid_t, grid_h, grid_w))) in images.iter().zip(&plans) {
            let (w, h) = image.dimensions();
            let (tw32, th32) = (target_w as u32, target_h as u32);
            let resized;
            let img_ref = if w != tw32 || h != th32 {
                resized = resize(image, tw32, th32, filter);
                &resized
            } else {
                image
            };

            grid_thw_data.extend([grid_t as i64, grid_h as i64, grid_w as i64]);
            num_img_tokens.push(self.calculate_tokens_from_grid(grid_t, grid_h, grid_w));

            let tensor = if do_normalize {
                to_tensor_and_normalize(img_ref, &mean, &std)
            } else {
                to_tensor(img_ref)
            };
            self.patchify_into(&tensor, grid_t, grid_h, grid_w, &mut all_patches)?;
            patches_per_image.push((grid_t * grid_h * grid_w) as i64);
        }

        let pixel_values =
            Array2::from_shape_vec((total_patches, patch_features), all_patches).map_err(|e| {
                TransformError::ShapeError(format!(
                    "Failed to create patchified pixel_values [{total_patches}, {patch_features}]: {e}"
                ))
            })?;

        let result =
            PreprocessedImages::new_dynamic(pixel_values.into_dyn(), num_img_tokens, image_sizes)
                .with_extra(
                    "image_grid_thw",
                    ModelSpecificValue::int_2d(grid_thw_data, images.len(), 3),
                )
                .with_extra(
                    "patches_per_image",
                    ModelSpecificValue::int_1d(patches_per_image),
                );
        Ok(result)
    }

    /// Number of tokens `preprocess` produces for an image of the given size.
    ///
    /// Uses the `smart_resize` target when `config.do_resize` is unset or `true`,
    /// falling back to a single `factor x factor` tile if `smart_resize` rejects
    /// the size. When resizing is disabled the size is used as given.
    fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize {
        let (height, width) = (height as usize, width as usize);
        let (new_height, new_width) = if config.do_resize.unwrap_or(true) {
            self.smart_resize(height, width).unwrap_or_else(|_| {
                let factor = self.get_factor();
                (factor, factor)
            })
        } else {
            (height, width)
        };
        let (grid_t, grid_h, grid_w) = self.calculate_grid_thw(new_height, new_width, 1);
        self.calculate_tokens_from_grid(grid_t, grid_h, grid_w)
    }

    /// Model identifier from the processor configuration.
    fn model_name(&self) -> &'static str {
        self.config.model_name
    }

    /// Always `None`: the output size is derived from each input image, so there
    /// is no single processed size to report.
    fn get_processed_size(&self, _config: &PreProcessorConfig) -> Option<(u32, u32)> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `patch_size (14) * merge_size (2)` for the fixture below.
    const FACTOR: usize = 28;

    /// Builds the processor shared by every test in this module.
    fn processor() -> QwenVLProcessorBase {
        QwenVLProcessorBase::new(QwenVLConfig {
            patch_size: 14,
            merge_size: 2,
            min_pixels: 256 * 28 * 28,
            max_pixels: 1280 * 28 * 28,
            temporal_patch_size: 2,
            mean: [0.5, 0.5, 0.5],
            std: [0.5, 0.5, 0.5],
            model_name: "test-qwen-vl",
        })
    }

    #[test]
    fn test_qwen_vl_base_factor() {
        assert_eq!(processor().get_factor(), FACTOR);
    }

    #[test]
    fn test_smart_resize_within_bounds() {
        let base = processor();
        let (height, width) = base.smart_resize(500, 500).unwrap();
        let area = height * width;

        // Both sides are snapped to the factor and the area stays in range.
        assert_eq!(height % FACTOR, 0);
        assert_eq!(width % FACTOR, 0);
        assert!(area >= base.min_pixels());
        assert!(area <= base.max_pixels());
    }

    #[test]
    fn test_smart_resize_extreme_aspect_ratio_error() {
        // 300:1 is beyond the 200:1 limit.
        assert!(processor().smart_resize(100, 30000).is_err());
    }

    #[test]
    fn test_calculate_grid_thw() {
        let grid = processor().calculate_grid_thw(448, 448, 1);
        assert_eq!(grid.0, 1);
        assert_eq!(grid.1, 448 / 14);
        assert_eq!(grid.2, 448 / 14);
    }

    #[test]
    fn test_calculate_tokens() {
        let expected = (32 * 32) / 4;
        assert_eq!(processor().calculate_tokens_from_grid(1, 32, 32), expected);
    }
}