use std::collections::HashSet;
use image::{imageops::FilterType, DynamicImage, GenericImageView, Rgb, RgbImage};
use ndarray::{s, Array3, Array4, IxDyn};
use crate::vision::{
image_processor::{ImagePreProcessor, ModelSpecificValue, PreprocessedImages},
preprocessor_config::PreProcessorConfig,
transforms::{self, TransformError},
};
/// Default normalization mean, one value per colour channel.
pub const LLAMA4_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
/// Default normalization standard deviation, one value per colour channel.
pub const LLAMA4_STD: [f64; 3] = [0.5, 0.5, 0.5];
/// Default edge length, in pixels, of one square tile.
pub const TILE_SIZE: u32 = 336;
/// Default upper bound on the number of tiles a single image may be split into.
pub const DEFAULT_MAX_PATCHES: usize = 16;
/// Divisor applied to the tile edge when counting tokens:
/// one tile contributes `(tile_size / PATCH_SIZE)^2` tokens.
pub const PATCH_SIZE: usize = 14;
/// Image preprocessor that resizes an image onto a canvas made of square tiles,
/// normalizes it and splits it into per-tile tensors.
#[derive(Debug, Clone)]
pub struct Llama4VisionProcessor {
/// Edge length of one square tile, in pixels.
tile_size: u32,
/// Maximum number of tiles a canvas may consist of.
max_patches: usize,
/// When `true`, the largest upscaling canvas is preferred and the image is resized
/// to fill it; when `false`, the smallest upscaling canvas is chosen and the resize
/// target is capped at `max(original dimension, tile_size)`. Every constructor in
/// this file sets it to `false`.
resize_to_max_canvas: bool,
/// Per-channel mean passed to `transforms::normalize`.
mean: [f64; 3],
/// Per-channel standard deviation passed to `transforms::normalize`.
std: [f64; 3],
}
/// `Default` is equivalent to [`Llama4VisionProcessor::new`].
impl Default for Llama4VisionProcessor {
fn default() -> Self {
Self::new()
}
}
impl Llama4VisionProcessor {
pub fn new() -> Self {
Self {
tile_size: TILE_SIZE,
max_patches: DEFAULT_MAX_PATCHES,
resize_to_max_canvas: false,
mean: LLAMA4_MEAN,
std: LLAMA4_STD,
}
}
/// Creates a processor with a custom tile budget; every other setting uses the
/// module defaults (`TILE_SIZE`, `LLAMA4_MEAN`, `LLAMA4_STD`).
pub fn with_max_patches(max_patches: usize) -> Self {
Self {
tile_size: TILE_SIZE,
max_patches,
resize_to_max_canvas: false,
mean: LLAMA4_MEAN,
std: LLAMA4_STD,
}
}
/// Builds a processor from a [`PreProcessorConfig`].
///
/// Every value that is missing *or unusable* falls back to the module default:
///
/// * `size["height"]` becomes the tile edge (ignored when absent or zero),
/// * `max_image_tiles` becomes the tile budget (ignored when absent or zero),
/// * the first three entries of `image_mean` / `image_std` become the normalization
///   constants (ignored when absent or shorter than three values).
///
/// The function never panics on malformed configuration.
pub fn from_preprocessor_config(config: &PreProcessorConfig) -> Self {
    // Checked extraction of an RGB triplet: direct indexing (`v[0]`, `v[1]`, `v[2]`)
    // would panic on a vector with fewer than three entries.
    let rgb_triplet = |values: &[f64]| match values {
        [r, g, b, ..] => Some([*r, *g, *b]),
        _ => None,
    };

    Self {
        // A zero edge would later cause a division by zero when counting tiles.
        tile_size: config
            .size
            .as_ref()
            .and_then(|s| s.get("height").copied())
            .filter(|&edge| edge > 0)
            .unwrap_or(TILE_SIZE),
        // A zero budget would leave no canvas to choose from.
        max_patches: config
            .max_image_tiles
            .filter(|&budget| budget > 0)
            .unwrap_or(DEFAULT_MAX_PATCHES),
        resize_to_max_canvas: false,
        mean: config
            .image_mean
            .as_ref()
            .and_then(|v| rgb_triplet(v.as_slice()))
            .unwrap_or(LLAMA4_MEAN),
        std: config
            .image_std
            .as_ref()
            .and_then(|v| rgb_triplet(v.as_slice()))
            .unwrap_or(LLAMA4_STD),
    }
}
/// Returns the edge length of one square tile, in pixels.
pub fn tile_size(&self) -> u32 {
self.tile_size
}
/// Returns the maximum number of tiles a canvas may consist of.
pub fn max_patches(&self) -> usize {
self.max_patches
}
/// Returns the set of all positive divisors of `n` (empty for `n == 0`).
fn get_factors(n: usize) -> HashSet<usize> {
    // Only candidates up to the square root need testing; every hit also
    // contributes its cofactor `n / divisor`.
    let limit = (n as f64).sqrt() as usize;
    (1..=limit)
        .filter(|&candidate| n.is_multiple_of(candidate))
        .flat_map(|divisor| [divisor, n / divisor])
        .collect()
}
/// Enumerates every canvas size `(height, width)`, in pixels, that can be assembled
/// from `1..=max_patches` whole tiles arranged in a rectangular grid.
///
/// Canvases are listed from the largest tile count down to one tile; within one tile
/// count they are ordered by ascending number of tile rows, so the result is
/// reproducible from call to call.
fn find_supported_resolutions(&self) -> Vec<(u32, u32)> {
    let tile = self.tile_size;
    let mut resolutions = Vec::new();
    for chunk_size in (1..=self.max_patches).rev() {
        // `HashSet` iteration order is unspecified. Sorting keeps the output stable,
        // which matters because the best-fit search breaks ties by first occurrence.
        let mut row_counts: Vec<usize> = Self::get_factors(chunk_size).into_iter().collect();
        row_counts.sort_unstable();
        resolutions.extend(row_counts.into_iter().map(|h_tiles| {
            let w_tiles = chunk_size / h_tiles;
            (h_tiles as u32 * tile, w_tiles as u32 * tile)
        }));
    }
    resolutions
}
/// Computes the largest size `(height, width)` that fits inside `target_size` while
/// keeping the aspect ratio of `image_size`. Both tuples are `(height, width)`.
///
/// The dimension with the smaller scale factor is snapped to the target; the other
/// one is scaled by the same factor, floored and clamped to the target.
fn get_max_res_without_distortion(
    image_size: (u32, u32),
    target_size: (u32, u32),
) -> (u32, u32) {
    let (src_h, src_w) = image_size;
    let (dst_h, dst_w) = target_size;
    let width_ratio = dst_w as f64 / src_w as f64;
    let height_ratio = dst_h as f64 / src_h as f64;

    // Width is the limiting dimension: fill it and derive the height.
    if width_ratio < height_ratio {
        let scaled_h = (src_h as f64 * width_ratio).floor() as u32;
        return (scaled_h.min(dst_h), dst_w);
    }

    // Otherwise height is limiting (or both tie): fill it and derive the width.
    let scaled_w = (src_w as f64 * height_ratio).floor() as u32;
    (dst_h, scaled_w.min(dst_w))
}
/// Picks the supported canvas `(height, width)` that best fits an image of
/// `image_size` (`(height, width)`).
///
/// For every supported canvas the uniform scale that makes the image fit is
/// computed (the smaller of the width and height ratios). Then:
///
/// * if at least one canvas allows upscaling (scale >= 1), the smallest such scale is
///   selected, or the largest one when `resize_to_max_canvas` is set;
/// * otherwise the largest downscaling factor is selected.
///
/// Among the canvases sharing the selected scale, the one with the smallest area
/// wins (the first one on ties).
///
/// If no canvas qualifies, for instance when the list of supported resolutions is
/// empty, a single tile `(tile_size, tile_size)` is returned instead of panicking.
fn get_best_fit(&self, image_size: (u32, u32)) -> (u32, u32) {
    let (orig_h, orig_w) = image_size;
    let scored: Vec<(f64, (u32, u32))> = self
        .find_supported_resolutions()
        .into_iter()
        .map(|(target_h, target_w)| {
            let scale_w = f64::from(target_w) / f64::from(orig_w);
            let scale_h = f64::from(target_h) / f64::from(orig_h);
            (scale_w.min(scale_h), (target_h, target_w))
        })
        .collect();

    let upscaling: Vec<f64> = scored
        .iter()
        .map(|&(scale, _)| scale)
        .filter(|&scale| scale >= 1.0)
        .collect();

    let selected_scale = if upscaling.is_empty() {
        // Nothing fits without shrinking: shrink as little as possible.
        scored
            .iter()
            .map(|&(scale, _)| scale)
            .filter(|&scale| scale < 1.0)
            .fold(f64::NEG_INFINITY, f64::max)
    } else if self.resize_to_max_canvas {
        upscaling.iter().copied().fold(f64::NEG_INFINITY, f64::max)
    } else {
        upscaling.iter().copied().fold(f64::INFINITY, f64::min)
    };

    scored
        .iter()
        // The exact comparison covers infinite scales (zero-sized input), for which
        // the difference is NaN and the tolerance test alone would reject everything.
        .filter(|&&(scale, _)| scale == selected_scale || (scale - selected_scale).abs() < 1e-9)
        .map(|&(_, resolution)| resolution)
        // Widen before multiplying so the area cannot overflow `u32`.
        .min_by_key(|&(h, w)| u64::from(h) * u64::from(w))
        .unwrap_or((self.tile_size, self.tile_size))
}
/// Places `image` in the top-left corner of a black `target_w` x `target_h` RGB canvas.
///
/// When the image already has exactly the target dimensions it is returned as an
/// unmodified clone.
fn pad_image(&self, image: &DynamicImage, target_w: u32, target_h: u32) -> DynamicImage {
    if image.dimensions() == (target_w, target_h) {
        image.clone()
    } else {
        let mut canvas = RgbImage::from_pixel(target_w, target_h, Rgb([0u8, 0, 0]));
        image::imageops::overlay(&mut canvas, &image.to_rgb8(), 0, 0);
        DynamicImage::ImageRgb8(canvas)
    }
}
/// Cuts `tensor` into a `num_tiles_h` x `num_tiles_w` grid of square tiles.
///
/// The result has shape `(num_tiles_h * num_tiles_w, 3, tile_size, tile_size)`;
/// tiles are stored row by row (all tiles of the first grid row first).
fn split_to_tiles(
    &self,
    tensor: &Array3<f32>,
    num_tiles_h: usize,
    num_tiles_w: usize,
) -> Array4<f32> {
    let edge = self.tile_size as usize;
    let mut tiles = Array4::<f32>::zeros((num_tiles_h * num_tiles_w, 3, edge, edge));
    // Walk the output tiles in storage order and derive each tile's grid position.
    for (index, mut destination) in tiles.outer_iter_mut().enumerate() {
        let (row, col) = (index / num_tiles_w, index % num_tiles_w);
        let (top, left) = (row * edge, col * edge);
        destination.assign(&tensor.slice(s![.., top..top + edge, left..left + edge]));
    }
    tiles
}
/// Produces the "global" tile: the whole image resized to exactly one tile
/// (aspect ratio is not preserved), converted to a tensor and normalized.
fn create_global_image(&self, image: &DynamicImage) -> Array3<f32> {
    let edge = self.tile_size;
    let mut global =
        transforms::to_tensor(&image.resize_exact(edge, edge, FilterType::Triangle));
    transforms::normalize(&mut global, &self.mean, &self.std);
    global
}
/// Runs the full pipeline for one image and returns its tiles together with the
/// tile grid `(rows, columns)`.
///
/// Steps: choose the best-fitting canvas, resize the image without distortion,
/// pad it to the canvas, convert to a normalized tensor and split it into tiles.
/// When the canvas has more than one tile, a global tile (the whole image squeezed
/// into a single tile) is appended as the last entry.
fn process_single_image(
    &self,
    image: &DynamicImage,
) -> Result<(Array4<f32>, (usize, usize)), TransformError> {
    let (orig_w, orig_h) = image.dimensions();
    let image_size = (orig_h, orig_w);
    let (target_h, target_w) = self.get_best_fit(image_size);

    let resize_target = if self.resize_to_max_canvas {
        (target_h, target_w)
    } else {
        // Do not enlarge beyond the original size, but allow at least one tile.
        let lower_bound = self.tile_size;
        (
            target_h.min(orig_h.max(lower_bound)),
            target_w.min(orig_w.max(lower_bound)),
        )
    };

    let (fit_h, fit_w) = Self::get_max_res_without_distortion(image_size, resize_target);
    // Guard against a zero-sized resize request.
    let resized = image.resize_exact(fit_w.max(1), fit_h.max(1), FilterType::Triangle);
    let padded = self.pad_image(&resized, target_w, target_h);
    let mut tensor = transforms::to_tensor(&padded);
    transforms::normalize(&mut tensor, &self.mean, &self.std);

    let edge = self.tile_size as usize;
    let grid = (target_h as usize / edge, target_w as usize / edge);
    let tiles = self.split_to_tiles(&tensor, grid.0, grid.1);
    let local_count = grid.0 * grid.1;

    // A single tile already shows the whole image; no global tile is needed.
    if local_count <= 1 {
        return Ok((tiles, grid));
    }

    let global_tile = self.create_global_image(image);
    let mut with_global = Array4::<f32>::zeros((local_count + 1, 3, edge, edge));
    with_global
        .slice_mut(s![..local_count, .., .., ..])
        .assign(&tiles);
    with_global
        .slice_mut(s![local_count, .., .., ..])
        .assign(&global_tile);
    Ok((with_global, grid))
}
/// Returns the number of image tokens produced for a tile grid `(rows, columns)`.
///
/// Grids with more than one tile carry one extra (global) tile; each tile
/// contributes `(tile_size / PATCH_SIZE)^2` tokens.
pub fn calculate_num_tokens_for_aspect_ratio(&self, aspect_ratio: (usize, usize)) -> usize {
    let local_tiles = aspect_ratio.0 * aspect_ratio.1;
    let global_tiles = usize::from(local_tiles > 1);
    let patches_per_side = self.tile_size as usize / PATCH_SIZE;
    (local_tiles + global_tiles) * patches_per_side.pow(2)
}
}
impl ImagePreProcessor for Llama4VisionProcessor {
/// Returns the normalization mean this processor instance is configured with.
fn default_mean(&self) -> [f64; 3] {
self.mean
}
/// Returns the normalization standard deviation this processor instance is configured with.
fn default_std(&self) -> [f64; 3] {
self.std
}
fn preprocess(
&self,
images: &[DynamicImage],
config: &PreProcessorConfig,
) -> Result<PreprocessedImages, TransformError> {
if images.is_empty() {
return Err(TransformError::InvalidShape {
expected: "non-empty image batch".to_string(),
actual: vec![0],
});
}
let processor = if config.max_image_tiles.is_some()
|| config.image_mean.is_some()
|| config.image_std.is_some()
|| config.size.is_some()
{
Self::from_preprocessor_config(config)
} else {
self.clone()
};
let mut all_outputs = Vec::new();
let mut all_aspect_ratios = Vec::new();
let mut image_sizes = Vec::new();
let mut num_img_tokens = Vec::new();
for image in images {
let (output, aspect_ratio) = processor.process_single_image(image)?;
let tokens = processor.calculate_num_tokens_for_aspect_ratio(aspect_ratio);
all_outputs.push(output);
all_aspect_ratios.push(aspect_ratio);
image_sizes.push((image.height(), image.width()));
num_img_tokens.push(tokens);
}
let max_tiles = all_outputs.iter().map(|o| o.shape()[0]).max().unwrap();
let tile = self.tile_size as usize;
let batch_size = images.len();
let mut pixel_values =
ndarray::ArrayD::<f32>::zeros(IxDyn(&[batch_size, max_tiles, 3, tile, tile]));
for (b, output) in all_outputs.iter().enumerate() {
let num_tiles = output.shape()[0];
for t in 0..num_tiles {
pixel_values
.slice_mut(s![b, t, .., .., ..])
.assign(&output.slice(s![t, .., .., ..]));
}
}
let mut model_specific = std::collections::HashMap::new();
let aspect_ratios_flat: Vec<u32> = all_aspect_ratios
.iter()
.flat_map(|&(h, w)| vec![h as u32, w as u32])
.collect();
model_specific.insert(
"aspect_ratios".to_string(),
ModelSpecificValue::UintTensor {
data: aspect_ratios_flat,
shape: vec![batch_size, 2],
},
);
Ok(PreprocessedImages {
pixel_values: pixel_values.into_dyn(),
num_img_tokens,
image_sizes,
model_specific,
})
}
fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize {
let processor = Self::from_preprocessor_config(config);
let image_size = (height, width);
let target_size = processor.get_best_fit(image_size);
let tile = processor.tile_size as usize;
let num_tiles_h = target_size.0 as usize / tile;
let num_tiles_w = target_size.1 as usize / tile;
processor.calculate_num_tokens_for_aspect_ratio((num_tiles_h, num_tiles_w))
}
/// Returns the fixed identifier of this preprocessor.
fn model_name(&self) -> &'static str {
"llama4-vision"
}
/// Always returns `None`; `config` is ignored.
// NOTE(review): presumably `None` because the output size depends on each image's
// dimensions rather than on the configuration alone — confirm against the trait contract.
fn get_processed_size(&self, config: &PreProcessorConfig) -> Option<(u32, u32)> {
let _ = config;
None
}
}
// Unit tests for the Llama 4 vision preprocessor.
#[cfg(test)]
mod tests {
use super::*;
// Builds a solid-colour RGB image of the requested size.
fn create_test_image(width: u32, height: u32, color: Rgb<u8>) -> DynamicImage {
DynamicImage::from(RgbImage::from_pixel(width, height, color))
}
// `new()` must pick up all module-level defaults.
#[test]
fn test_llama4_vision_processor_default() {
let processor = Llama4VisionProcessor::new();
assert_eq!(processor.tile_size(), TILE_SIZE);
assert_eq!(processor.max_patches(), DEFAULT_MAX_PATCHES);
assert_eq!(processor.mean, LLAMA4_MEAN);
assert_eq!(processor.std, LLAMA4_STD);
}
// 12 has exactly the divisors 1, 2, 3, 4, 6 and 12.
#[test]
fn test_get_factors() {
let factors = Llama4VisionProcessor::get_factors(12);
assert!(factors.contains(&1));
assert!(factors.contains(&2));
assert!(factors.contains(&3));
assert!(factors.contains(&4));
assert!(factors.contains(&6));
assert!(factors.contains(&12));
assert_eq!(factors.len(), 6);
}
// With a budget of four tiles every grid of 1 to 4 tiles must be offered
// (membership only; the order of the list is not checked).
#[test]
fn test_find_supported_resolutions() {
let processor = Llama4VisionProcessor::with_max_patches(4);
let resolutions = processor.find_supported_resolutions();
let expected: Vec<(u32, u32)> = vec![
(336, 336), (336, 672), (672, 336), (336, 1008), (1008, 336), (672, 672), (336, 1344), (1344, 336), ];
for exp in expected {
assert!(
resolutions.contains(&exp),
"Expected resolution {:?} not found",
exp
);
}
}
// A square image should get a square canvas, or one whose sides differ by at most one tile.
#[test]
fn test_get_best_fit_square() {
let processor = Llama4VisionProcessor::new();
let best = processor.get_best_fit((500, 500));
assert!(best.0 == best.1 || (best.0 as i32 - best.1 as i32).abs() <= 336);
}
// Sizes are (height, width): a wide image must not get a canvas taller than wide.
#[test]
fn test_get_best_fit_wide() {
let processor = Llama4VisionProcessor::new();
let best = processor.get_best_fit((300, 900));
assert!(best.1 >= best.0);
}
// A tall image must not get a canvas wider than tall.
#[test]
fn test_get_best_fit_tall() {
let processor = Llama4VisionProcessor::new();
let best = processor.get_best_fit((900, 300));
assert!(best.0 >= best.1);
}
// End-to-end run on one square image; all output values are expected to lie
// within [-1.5, 1.5].
#[test]
fn test_preprocess_square_image() {
let processor = Llama4VisionProcessor::new();
let config = PreProcessorConfig::default();
let image = create_test_image(500, 500, Rgb([128, 128, 128]));
let result = processor.preprocess(&[image], &config).unwrap();
assert_eq!(result.batch_size(), 1);
assert!(result.num_img_tokens[0] > 0);
let flat = result.pixel_values_flat();
assert!(flat.iter().all(|&v| (-1.5..=1.5).contains(&v)));
}
// A wide image must yield at least as many tile columns as tile rows.
#[test]
fn test_preprocess_wide_image() {
let processor = Llama4VisionProcessor::new();
let config = PreProcessorConfig::default();
let image = create_test_image(1000, 300, Rgb([128, 128, 128]));
let result = processor.preprocess(&[image], &config).unwrap();
assert_eq!(result.batch_size(), 1);
let aspect_ratios = result.model_specific.get("aspect_ratios").unwrap();
if let ModelSpecificValue::UintTensor { data, .. } = aspect_ratios {
let h_tiles = data[0];
let w_tiles = data[1];
assert!(w_tiles >= h_tiles);
}
}
// Per-image metadata must have one entry per input image.
#[test]
fn test_preprocess_multiple_images() {
let processor = Llama4VisionProcessor::new();
let config = PreProcessorConfig::default();
let images = vec![
create_test_image(500, 500, Rgb([100, 100, 100])),
create_test_image(800, 400, Rgb([150, 150, 150])),
];
let result = processor.preprocess(&images, &config).unwrap();
assert_eq!(result.batch_size(), 2);
assert_eq!(result.image_sizes.len(), 2);
assert_eq!(result.num_img_tokens.len(), 2);
}
// When the grid has more than one tile, the tile axis must hold one extra (global) tile.
#[test]
fn test_global_tile_added_for_multiple_tiles() {
let processor = Llama4VisionProcessor::new();
let config = PreProcessorConfig::default();
let image = create_test_image(1000, 1000, Rgb([128, 128, 128]));
let result = processor.preprocess(&[image], &config).unwrap();
let aspect_ratios = result.model_specific.get("aspect_ratios").unwrap();
if let ModelSpecificValue::UintTensor { data, .. } = aspect_ratios {
let h_tiles = data[0] as usize;
let w_tiles = data[1] as usize;
let num_tiles = h_tiles * w_tiles;
if num_tiles > 1 {
let shape = result.pixel_values.shape();
assert_eq!(shape[1], num_tiles + 1);
}
}
}
// The model identifier is a fixed string.
#[test]
fn test_model_name() {
let processor = Llama4VisionProcessor::new();
assert_eq!(processor.model_name(), "llama4-vision");
}
// The trait accessors must expose the default normalization constants.
#[test]
fn test_normalization_values() {
let processor = Llama4VisionProcessor::new();
assert_eq!(processor.default_mean(), [0.5, 0.5, 0.5]);
assert_eq!(processor.default_std(), [0.5, 0.5, 0.5]);
}
// 336 / 14 = 24 patches per side, i.e. 576 tokens per tile; grids with more than
// one tile add one global tile: (2, 2) -> 5 tiles, (1, 2) -> 3 tiles.
#[test]
fn test_token_count_calculation() {
let processor = Llama4VisionProcessor::new();
assert_eq!(processor.calculate_num_tokens_for_aspect_ratio((1, 1)), 576);
assert_eq!(
processor.calculate_num_tokens_for_aspect_ratio((2, 2)),
2880
);
assert_eq!(
processor.calculate_num_tokens_for_aspect_ratio((1, 2)),
1728
);
}
}