use std::collections::HashSet;
use image::{imageops::FilterType, DynamicImage, GenericImageView, Rgb, RgbImage};
use ndarray::{s, Array2, Array3, Array4, IxDyn};
use crate::vision::{
image_processor::{ImagePreProcessor, ModelSpecificValue, PreprocessedImages},
preprocessor_config::PreProcessorConfig,
transforms::{self, TransformError},
};
/// Per-channel mean subtracted during normalization.
pub const PHI4_MEAN: [f64; 3] = [0.5; 3];
/// Per-channel standard deviation divided out during normalization.
pub const PHI4_STD: [f64; 3] = [0.5; 3];
/// Default upper bound on the number of HD crops an image may be split into.
pub const DEFAULT_DYNAMIC_HD: usize = 36;
/// Side length, in pixels, of every square crop and of the global view.
pub const BASE_RESOLUTION: u32 = 448;
/// Pixels of padding that correspond to one masked attention-mask row/column.
pub const PATCH_SIZE: usize = 14;
/// Attention-mask cells along one side of a crop (`BASE_RESOLUTION / PATCH_SIZE`).
pub const MASK_RESOLUTION: usize = BASE_RESOLUTION as usize / PATCH_SIZE;
/// Output of `Phi4VisionProcessor::process_single_image`, in order:
/// the stacked crops (global view at index 0, HD tiles after it), the per-crop
/// attention masks (same crop order), the padded HD image size as `(height, width)`,
/// and the image token count.
type SingleImageResult = (Array4<f32>, Array3<u32>, (u32, u32), usize);
/// Image preprocessor for Phi-4 multimodal vision inputs.
///
/// Each image is resized and padded onto a grid of `base_resolution`-sized crops,
/// accompanied by a downscaled global view, per-crop attention masks and a token count.
#[derive(Clone, Debug)]
pub struct Phi4VisionProcessor {
    /// Maximum number of HD crops per image.
    dynamic_hd: usize,
    /// Side length, in pixels, of each crop.
    base_resolution: u32,
    /// Attention-mask cells along one side of a crop.
    mask_resolution: usize,
    /// Per-channel normalization mean.
    mean: [f64; 3],
    /// Per-channel normalization standard deviation.
    std: [f64; 3],
}
impl Default for Phi4VisionProcessor {
fn default() -> Self {
Self::new()
}
}
impl Phi4VisionProcessor {
    /// Creates a processor with the stock Phi-4 settings: `DEFAULT_DYNAMIC_HD` crops,
    /// `BASE_RESOLUTION` crops, `MASK_RESOLUTION` mask cells per crop side and
    /// `PHI4_MEAN` / `PHI4_STD` normalization.
    pub fn new() -> Self {
        Self::with_dynamic_hd(DEFAULT_DYNAMIC_HD)
    }

    /// Creates a processor with the stock settings but a custom crop budget.
    pub fn with_dynamic_hd(dynamic_hd: usize) -> Self {
        Self {
            dynamic_hd,
            base_resolution: BASE_RESOLUTION,
            mask_resolution: MASK_RESOLUTION,
            mean: PHI4_MEAN,
            std: PHI4_STD,
        }
    }

    /// Builds a processor from `config`, using the stock value for every setting the
    /// config leaves unset.
    ///
    /// `image_mean` / `image_std` are accepted when they hold exactly one value (applied
    /// to all three channels) or at least three values (the first three are used). Any
    /// other length falls back to the stock value instead of panicking on an
    /// out-of-range index.
    pub fn from_preprocessor_config(config: &PreProcessorConfig) -> Self {
        Self {
            dynamic_hd: config.dynamic_hd.unwrap_or(DEFAULT_DYNAMIC_HD),
            base_resolution: BASE_RESOLUTION,
            mask_resolution: MASK_RESOLUTION,
            mean: config
                .image_mean
                .as_ref()
                .and_then(|v| Self::channel_triplet(&v[..]))
                .unwrap_or(PHI4_MEAN),
            std: config
                .image_std
                .as_ref()
                .and_then(|v| Self::channel_triplet(&v[..]))
                .unwrap_or(PHI4_STD),
        }
    }

    /// Converts a per-channel config list into a three-channel array, or `None` when the
    /// list has an unusable length (see `from_preprocessor_config`).
    fn channel_triplet(values: &[f64]) -> Option<[f64; 3]> {
        match values {
            [c0, c1, c2, ..] => Some([*c0, *c1, *c2]),
            [all] => Some([*all; 3]),
            _ => None,
        }
    }

    /// Maximum number of HD crops per image.
    pub fn dynamic_hd(&self) -> usize {
        self.dynamic_hd
    }

    /// Side length, in pixels, of each crop.
    pub fn base_resolution(&self) -> u32 {
        self.base_resolution
    }

    /// Enumerates every `(w, h)` crop grid whose tile count `w * h` lies in
    /// `min_num..=max_num`, ordered by tile count.
    ///
    /// Grids with the same tile count are ordered by `(w, h)`. This keeps the result
    /// deterministic, so tie-breaking in `find_closest_aspect_ratio` is stable; ordering
    /// taken from `HashSet` iteration varies from run to run.
    fn compute_target_ratios(&self, min_num: usize, max_num: usize) -> Vec<(usize, usize)> {
        let mut ratios: Vec<(usize, usize)> = (min_num..=max_num)
            .flat_map(|n| {
                // Divisors up to sqrt(n) yield every factor pair (i, n / i) and its mirror.
                (1..)
                    .take_while(move |i| i * i <= n)
                    .filter(move |i| n % i == 0)
                    .flat_map(move |i| [(i, n / i), (n / i, i)])
            })
            .collect();
        ratios.sort_unstable_by_key(|&(w, h)| (w * h, w, h));
        // Square grids (i == n / i) were produced twice and are now adjacent.
        ratios.dedup();
        ratios
    }

    /// Picks the grid from `target_ratios` whose `w / h` is closest to `aspect_ratio`.
    ///
    /// On a tie (within 1e-6), a later candidate replaces the current best when the
    /// image area exceeds half the pixel area of that candidate's grid.
    fn find_closest_aspect_ratio(
        &self,
        aspect_ratio: f64,
        target_ratios: &[(usize, usize)],
        width: u32,
        height: u32,
    ) -> (usize, usize) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        // Multiply in f64: `width * height` in u32 overflows for very large images.
        let area = f64::from(width) * f64::from(height);
        let base_area = f64::from(self.base_resolution) * f64::from(self.base_resolution);
        for &(w_ratio, h_ratio) in target_ratios {
            let target_aspect_ratio = w_ratio as f64 / h_ratio as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = (w_ratio, h_ratio);
            } else if (ratio_diff - best_ratio_diff).abs() < 1e-6
                && area > 0.5 * base_area * (w_ratio * h_ratio) as f64
            {
                best_ratio = (w_ratio, h_ratio);
            }
        }
        best_ratio
    }

    /// Resizes `image` (preserving aspect ratio) onto a crop grid and pads the remainder
    /// with white on the right or bottom.
    ///
    /// Returns the padded image, an attention mask with `mask_resolution` cells per crop
    /// side (cells covering padding are 0), and the grid size as `(h_crops, w_crops)`.
    ///
    /// # Errors
    /// `TransformError::InvalidShape` when the image has a zero-sized side, or when its
    /// aspect ratio is so extreme that the resized content would be 0 pixels thick
    /// (which would leave the attention mask entirely zero).
    fn dynamic_preprocess(
        &self,
        image: &DynamicImage,
    ) -> Result<(DynamicImage, Array2<u32>, usize, usize), TransformError> {
        let (orig_w, orig_h) = image.dimensions();
        if orig_w == 0 || orig_h == 0 {
            return Err(TransformError::InvalidShape {
                expected: "image with non-zero width and height".to_string(),
                actual: vec![orig_h as _, orig_w as _],
            });
        }

        // Use the natural grid when it fits the crop budget; otherwise the budgeted grid
        // whose aspect ratio is closest to the image's.
        let base_res = f64::from(self.base_resolution);
        let w_crop_num = (f64::from(orig_w) / base_res).ceil() as usize;
        let h_crop_num = (f64::from(orig_h) / base_res).ceil() as usize;
        let (target_w_crops, target_h_crops) = if w_crop_num * h_crop_num > self.dynamic_hd {
            let aspect_ratio = f64::from(orig_w) / f64::from(orig_h);
            let target_ratios = self.compute_target_ratios(1, self.dynamic_hd);
            self.find_closest_aspect_ratio(aspect_ratio, &target_ratios, orig_w, orig_h)
        } else {
            (w_crop_num, h_crop_num)
        };
        let target_width = self.base_resolution * target_w_crops as u32;
        let target_height = self.base_resolution * target_h_crops as u32;

        // Scale by the smaller factor so the whole image fits inside the target.
        let ratio_width = f64::from(target_width) / f64::from(orig_w);
        let ratio_height = f64::from(target_height) / f64::from(orig_h);
        let (new_w, new_h) = if ratio_width < ratio_height {
            (target_width, (f64::from(orig_h) * ratio_width) as u32)
        } else {
            ((f64::from(orig_w) * ratio_height) as u32, target_height)
        };
        if new_w == 0 || new_h == 0 {
            return Err(TransformError::InvalidShape {
                expected: "aspect ratio that leaves a non-empty resized image".to_string(),
                actual: vec![new_h as _, new_w as _],
            });
        }
        // Exactly one of these is non-zero: the fitted side matches its target exactly.
        let padding_width = target_width.saturating_sub(new_w);
        let padding_height = target_height.saturating_sub(new_h);

        // Every full PATCH_SIZE pixels of padding masks one trailing column/row;
        // padding thinner than one patch masks nothing.
        let mask_h = self.mask_resolution * target_h_crops;
        let mask_w = self.mask_resolution * target_w_crops;
        let mut attention_mask = Array2::<u32>::ones((mask_h, mask_w));
        let masked_cols = (padding_width as usize / PATCH_SIZE).min(mask_w);
        let masked_rows = (padding_height as usize / PATCH_SIZE).min(mask_h);
        attention_mask
            .slice_mut(s![.., (mask_w - masked_cols)..])
            .fill(0);
        attention_mask
            .slice_mut(s![(mask_h - masked_rows).., ..])
            .fill(0);

        let resized = image.resize_exact(new_w, new_h, FilterType::Triangle);
        let padded = self.pad_image(&resized, target_width, target_height);
        Ok((padded, attention_mask, target_h_crops, target_w_crops))
    }

    /// Places `image` at the top-left of a white `target_w x target_h` RGB canvas.
    /// Returns a clone unchanged when the image already has the target size.
    fn pad_image(&self, image: &DynamicImage, target_w: u32, target_h: u32) -> DynamicImage {
        let (w, h) = image.dimensions();
        if w == target_w && h == target_h {
            return image.clone();
        }
        let white = Rgb([255u8, 255, 255]);
        let mut padded = RgbImage::from_pixel(target_w, target_h, white);
        image::imageops::overlay(&mut padded, &image.to_rgb8(), 0, 0);
        DynamicImage::ImageRgb8(padded)
    }

    /// Downscales the whole HD tensor to a single `base_resolution` square via
    /// `transforms::bicubic_resize`.
    fn create_global_image(&self, tensor: &Array3<f32>) -> Array3<f32> {
        let target = self.base_resolution as usize;
        transforms::bicubic_resize(tensor, target, target)
    }

    /// Cuts the first three channels of `tensor` into `h_crops * w_crops` square tiles in
    /// row-major grid order.
    ///
    /// # Panics
    /// If `tensor` is smaller than the `h_crops x w_crops` grid of `base_resolution` tiles.
    fn tile_image(&self, tensor: &Array3<f32>, h_crops: usize, w_crops: usize) -> Array4<f32> {
        let base = self.base_resolution as usize;
        let mut tiles = Array4::<f32>::zeros((h_crops * w_crops, 3, base, base));
        for h_idx in 0..h_crops {
            for w_idx in 0..w_crops {
                let (y0, x0) = (h_idx * base, w_idx * base);
                tiles
                    .slice_mut(s![h_idx * w_crops + w_idx, .., .., ..])
                    .assign(&tensor.slice(s![0..3, y0..y0 + base, x0..x0 + base]));
            }
        }
        tiles
    }

    /// Keeps every second row and column of `mask`, starting at index 0, into a
    /// `(h_crops * mask_resolution / 2, w_crops * mask_resolution / 2)` array.
    /// Output cells with no corresponding source cell stay 0.
    fn downsample_mask(&self, mask: &Array2<u32>, h_crops: usize, w_crops: usize) -> Array2<u32> {
        let half_res = self.mask_resolution / 2;
        let (out_h, out_w) = (h_crops * half_res, w_crops * half_res);
        let strided = mask.slice(s![..;2, ..;2]);
        let rows = out_h.min(strided.nrows());
        let cols = out_w.min(strided.ncols());
        let mut downsampled = Array2::<u32>::zeros((out_h, out_w));
        downsampled
            .slice_mut(s![..rows, ..cols])
            .assign(&strided.slice(s![..rows, ..cols]));
        downsampled
    }

    /// Token count for one image: the fixed `256 + 1 + 16` overhead, plus one token per
    /// set cell of the downsampled mask, plus one per set cell in its first column.
    ///
    /// # Panics
    /// If `downsampled_mask` has no columns.
    fn calculate_num_tokens(&self, downsampled_mask: &Array2<u32>) -> usize {
        let mask_sum: u32 = downsampled_mask.iter().sum();
        let mask_col0_sum: u32 = downsampled_mask.column(0).iter().sum();
        256 + 1 + mask_sum as usize + mask_col0_sum as usize + 16
    }

    /// Runs the full pipeline for one image: resize/pad, normalize, global view, tiles,
    /// masks and token count.
    ///
    /// # Errors
    /// Propagates `dynamic_preprocess` errors for degenerate image sizes.
    fn process_single_image(
        &self,
        image: &DynamicImage,
    ) -> Result<SingleImageResult, TransformError> {
        let (hd_image, attention_mask, h_crops, w_crops) = self.dynamic_preprocess(image)?;
        let hd_h = hd_image.height();
        let hd_w = hd_image.width();
        let mut hd_tensor = transforms::to_tensor(&hd_image);
        transforms::normalize(&mut hd_tensor, &self.mean, &self.std);
        let global_tensor = self.create_global_image(&hd_tensor);
        let tiles = self.tile_image(&hd_tensor, h_crops, w_crops);

        // Crop 0 is the global view; HD tiles follow in row-major order.
        let num_hd_tiles = h_crops * w_crops;
        let base = self.base_resolution as usize;
        let total_crops = num_hd_tiles + 1;
        let mut output = Array4::<f32>::zeros((total_crops, 3, base, base));
        output.slice_mut(s![0, .., .., ..]).assign(&global_tensor);
        if num_hd_tiles > 0 {
            output.slice_mut(s![1.., .., .., ..]).assign(&tiles);
        }

        // The global view is never padded, so its mask is all ones; each tile gets its
        // own window of the HD attention mask.
        let mask_res = self.mask_resolution;
        let mut combined_mask = Array3::<u32>::zeros((total_crops, mask_res, mask_res));
        combined_mask.slice_mut(s![0, .., ..]).fill(1);
        for h_idx in 0..h_crops {
            for w_idx in 0..w_crops {
                let tile_idx = 1 + h_idx * w_crops + w_idx;
                let (y0, x0) = (h_idx * mask_res, w_idx * mask_res);
                combined_mask
                    .slice_mut(s![tile_idx, .., ..])
                    .assign(&attention_mask.slice(s![y0..y0 + mask_res, x0..x0 + mask_res]));
            }
        }

        let downsampled = self.downsample_mask(&attention_mask, h_crops, w_crops);
        let num_tokens = self.calculate_num_tokens(&downsampled);
        Ok((output, combined_mask, (hd_h, hd_w), num_tokens))
    }
}
impl ImagePreProcessor for Phi4VisionProcessor {
fn default_mean(&self) -> [f64; 3] {
self.mean
}
fn default_std(&self) -> [f64; 3] {
self.std
}
fn preprocess(
&self,
images: &[DynamicImage],
config: &PreProcessorConfig,
) -> Result<PreprocessedImages, TransformError> {
if images.is_empty() {
return Err(TransformError::InvalidShape {
expected: "non-empty image batch".to_string(),
actual: vec![0],
});
}
let processor = if config.dynamic_hd.is_some() || config.image_mean.is_some() {
Self::from_preprocessor_config(config)
} else {
self.clone()
};
let mut all_outputs = Vec::new();
let mut all_masks = Vec::new();
let mut image_sizes = Vec::new();
let mut num_img_tokens = Vec::new();
for image in images {
let (output, mask, size, tokens) = processor.process_single_image(image)?;
all_outputs.push(output);
all_masks.push(mask);
image_sizes.push(size);
num_img_tokens.push(tokens);
}
let max_crops = all_outputs.iter().map(|o| o.shape()[0]).max().unwrap();
let base = self.base_resolution as usize;
let mask_res = self.mask_resolution;
let batch_size = images.len();
let mut pixel_values =
ndarray::ArrayD::<f32>::zeros(IxDyn(&[batch_size, max_crops, 3, base, base]));
let mut attention_masks =
ndarray::ArrayD::<u32>::zeros(IxDyn(&[batch_size, max_crops, mask_res, mask_res]));
for (b, (output, mask)) in all_outputs.iter().zip(all_masks.iter()).enumerate() {
let num_crops = output.shape()[0];
for t in 0..num_crops {
for c in 0..3 {
for y in 0..base {
for x in 0..base {
pixel_values[[b, t, c, y, x]] = output[[t, c, y, x]];
}
}
}
for y in 0..mask_res {
for x in 0..mask_res {
attention_masks[[b, t, y, x]] = mask[[t, y, x]];
}
}
}
}
let mut model_specific = std::collections::HashMap::new();
let mask_flat: Vec<u32> = attention_masks.iter().copied().collect();
model_specific.insert(
"pixel_attention_mask".to_string(),
ModelSpecificValue::UintTensor {
data: mask_flat,
shape: vec![batch_size, max_crops, mask_res, mask_res],
},
);
let sizes_flat: Vec<u32> = image_sizes.iter().flat_map(|&(h, w)| vec![h, w]).collect();
model_specific.insert(
"image_sizes".to_string(),
ModelSpecificValue::UintTensor {
data: sizes_flat,
shape: vec![batch_size, 2],
},
);
Ok(PreprocessedImages {
pixel_values: pixel_values.into_dyn(),
num_img_tokens,
image_sizes,
model_specific,
})
}
fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize {
let processor = Self::from_preprocessor_config(config);
let base_res = processor.base_resolution as f64;
let w_crop_num = (width as f64 / base_res).ceil() as usize;
let h_crop_num = (height as f64 / base_res).ceil() as usize;
let (target_w_crops, target_h_crops) = if w_crop_num * h_crop_num > processor.dynamic_hd {
let aspect_ratio = width as f64 / height as f64;
let target_ratios = processor.compute_target_ratios(1, processor.dynamic_hd);
processor.find_closest_aspect_ratio(aspect_ratio, &target_ratios, width, height)
} else {
(w_crop_num, h_crop_num)
};
let half_res = processor.mask_resolution / 2;
let mask_area = target_h_crops * target_w_crops * half_res * half_res;
let mask_col0 = target_h_crops * half_res;
256 + 1 + mask_area + mask_col0 + 16
}
fn model_name(&self) -> &'static str {
"phi4-vision"
}
fn get_processed_size(&self, config: &PreProcessorConfig) -> Option<(u32, u32)> {
let _ = config;
None
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Solid-colour RGB image of the requested size.
    fn create_test_image(width: u32, height: u32, color: Rgb<u8>) -> DynamicImage {
        DynamicImage::ImageRgb8(RgbImage::from_pixel(width, height, color))
    }

    /// Candidate grids for the default crop budget.
    fn default_ratios(processor: &Phi4VisionProcessor) -> Vec<(usize, usize)> {
        processor.compute_target_ratios(1, DEFAULT_DYNAMIC_HD)
    }

    /// Runs `preprocess` on `images` with a stock processor and default config.
    fn run_default(images: &[DynamicImage]) -> PreprocessedImages {
        Phi4VisionProcessor::new()
            .preprocess(images, &PreProcessorConfig::default())
            .expect("preprocessing should succeed")
    }

    #[test]
    fn test_phi4_vision_processor_default() {
        let p = Phi4VisionProcessor::new();
        assert_eq!(
            (p.dynamic_hd(), p.base_resolution(), p.mean, p.std),
            (DEFAULT_DYNAMIC_HD, BASE_RESOLUTION, PHI4_MEAN, PHI4_STD)
        );
    }

    #[test]
    fn test_compute_target_ratios() {
        let ratios = Phi4VisionProcessor::new().compute_target_ratios(1, 4);
        for expected in [(1, 1), (2, 2), (1, 4), (4, 1)] {
            assert!(ratios.contains(&expected), "missing grid {:?}", expected);
        }
    }

    #[test]
    fn test_find_closest_aspect_ratio_square() {
        let p = Phi4VisionProcessor::new();
        let (w, h) = p.find_closest_aspect_ratio(1.0, &default_ratios(&p), 500, 500);
        assert_eq!(w, h);
    }

    #[test]
    fn test_find_closest_aspect_ratio_wide() {
        let p = Phi4VisionProcessor::new();
        let (w, h) = p.find_closest_aspect_ratio(2.0, &default_ratios(&p), 1000, 500);
        assert!(w > h);
    }

    #[test]
    fn test_find_closest_aspect_ratio_tall() {
        let p = Phi4VisionProcessor::new();
        let (w, h) = p.find_closest_aspect_ratio(0.5, &default_ratios(&p), 500, 1000);
        assert!(w < h);
    }

    #[test]
    fn test_pad_image() {
        let source = create_test_image(300, 200, Rgb([100, 100, 100]));
        let padded = Phi4VisionProcessor::new().pad_image(&source, 448, 448);
        assert_eq!((padded.width(), padded.height()), (448, 448));
        // Inside the original content vs. inside the white padding.
        assert_eq!(padded.get_pixel(100, 100).0[0], 100);
        assert_eq!(padded.get_pixel(400, 400).0[0], 255);
    }

    #[test]
    fn test_preprocess_square_image() {
        let result = run_default(&[create_test_image(500, 500, Rgb([128, 128, 128]))]);
        assert_eq!(result.batch_size(), 1);
        assert!(result.num_img_tokens[0] > 256);
        let in_range = result
            .pixel_values_flat()
            .iter()
            .all(|v| (-1.5..=1.5).contains(v));
        assert!(in_range);
    }

    #[test]
    fn test_preprocess_wide_image() {
        let result = run_default(&[create_test_image(1000, 500, Rgb([128, 128, 128]))]);
        assert_eq!(result.batch_size(), 1);
        let (height, width) = result.image_sizes[0];
        assert!(width >= height);
    }

    #[test]
    fn test_preprocess_multiple_images() {
        let batch = [
            create_test_image(500, 500, Rgb([100, 100, 100])),
            create_test_image(800, 400, Rgb([150, 150, 150])),
        ];
        let result = run_default(&batch);
        assert_eq!(result.batch_size(), 2);
        assert_eq!(result.image_sizes.len(), 2);
        assert_eq!(result.num_img_tokens.len(), 2);
    }

    #[test]
    fn test_model_name() {
        assert_eq!(Phi4VisionProcessor::new().model_name(), "phi4-vision");
    }

    #[test]
    fn test_normalization_values() {
        let p = Phi4VisionProcessor::new();
        assert_eq!(p.default_mean(), [0.5, 0.5, 0.5]);
        assert_eq!(p.default_std(), [0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_phi4_vs_phi3_differences() {
        let p = Phi4VisionProcessor::new();
        assert_eq!(p.base_resolution(), 448);
        assert_eq!(p.mean, [0.5, 0.5, 0.5]);
        assert_eq!(p.std, [0.5, 0.5, 0.5]);
        assert_eq!(p.dynamic_hd(), 36);
    }
}