use std::{borrow::Cow, collections::HashMap};
use image::DynamicImage;
use ndarray::{Array4, ArrayD};
use super::{preprocessor_config::PreProcessorConfig, transforms::TransformError};
/// Returns the length of one axis of a `pixel_values` tensor shape.
///
/// The axis index is chosen by rank: `axis_4d` for a 4-D tensor, `axis_5d` for a
/// 5-D tensor.
///
/// # Errors
/// Returns [`TransformError::InvalidShape`] when `ndim` is neither 4 nor 5.
///
/// # Panics
/// Panics if the chosen axis index is out of bounds for `shape` (for example when
/// `ndim` does not agree with `shape.len()`).
fn dim_for_ndim(
    ndim: usize,
    axis_4d: usize,
    axis_5d: usize,
    shape: &[usize],
) -> Result<usize, TransformError> {
    let axis = match ndim {
        4 => axis_4d,
        5 => axis_5d,
        other => {
            return Err(TransformError::InvalidShape {
                expected: format!("4D or 5D pixel_values tensor, got {other}D"),
                actual: shape.to_vec(),
            });
        }
    };
    Ok(shape[axis])
}
/// A model-specific auxiliary value attached to preprocessed images.
///
/// Tensor variants carry a flat `data` buffer together with the `shape` it
/// should be interpreted with. The constructors on this type fill in `shape`
/// for the common 1-D and 2-D cases.
#[derive(Debug, Clone)]
pub enum ModelSpecificValue {
    /// `f32` tensor: flat `data` plus its `shape`.
    Tensor { data: Vec<f32>, shape: Vec<usize> },
    /// `i64` tensor: flat `data` plus its `shape`.
    IntTensor { data: Vec<i64>, shape: Vec<usize> },
    /// `u32` tensor: flat `data` plus its `shape`.
    UintTensor { data: Vec<u32>, shape: Vec<usize> },
    /// A single signed integer.
    Int(i64),
    /// A single floating-point number.
    Float(f64),
    /// A list of signed integers.
    IntVec(Vec<i64>),
    /// A list of unsigned integers.
    UintVec(Vec<u32>),
    /// A list of `f32` values.
    FloatVec(Vec<f32>),
    /// A list of `(u32, u32)` pairs.
    TupleVec(Vec<(u32, u32)>),
    /// A boolean flag.
    Bool(bool),
}

impl ModelSpecificValue {
    /// Creates a 1-D [`ModelSpecificValue::UintTensor`] with shape `[data.len()]`.
    pub fn uint_1d(data: Vec<u32>) -> Self {
        let len = data.len();
        Self::UintTensor {
            data,
            shape: vec![len],
        }
    }

    /// Creates a 2-D [`ModelSpecificValue::UintTensor`] with shape `[rows, cols]`.
    ///
    /// # Panics
    /// In debug builds, panics if `data.len() != rows * cols`, so a buffer that
    /// disagrees with its declared shape is caught at construction. Release
    /// builds skip the check and store the value as given.
    pub fn uint_2d(data: Vec<u32>, rows: usize, cols: usize) -> Self {
        // checked_mul keeps the check itself from overflowing on huge shapes.
        debug_assert!(
            rows.checked_mul(cols) == Some(data.len()),
            "uint_2d: data has {} elements but shape is [{}, {}]",
            data.len(),
            rows,
            cols
        );
        Self::UintTensor {
            data,
            shape: vec![rows, cols],
        }
    }

    /// Creates a 1-D [`ModelSpecificValue::IntTensor`] with shape `[data.len()]`.
    pub fn int_1d(data: Vec<i64>) -> Self {
        let len = data.len();
        Self::IntTensor {
            data,
            shape: vec![len],
        }
    }

    /// Creates a 2-D [`ModelSpecificValue::IntTensor`] with shape `[rows, cols]`.
    ///
    /// # Panics
    /// In debug builds, panics if `data.len() != rows * cols`. Release builds
    /// skip the check and store the value as given.
    pub fn int_2d(data: Vec<i64>, rows: usize, cols: usize) -> Self {
        debug_assert!(
            rows.checked_mul(cols) == Some(data.len()),
            "int_2d: data has {} elements but shape is [{}, {}]",
            data.len(),
            rows,
            cols
        );
        Self::IntTensor {
            data,
            shape: vec![rows, cols],
        }
    }
}
/// Output of an image preprocessor: the pixel tensor plus per-image metadata.
///
/// The shape accessors treat a 4-D `pixel_values` as
/// `[batch, channels, height, width]`. They treat a 5-D tensor as the same layout
/// with one additional axis between batch and channels.
#[derive(Debug, Clone)]
pub struct PreprocessedImages {
    /// Pixel data; see the type-level docs for the layouts the accessors expect.
    pub pixel_values: ArrayD<f32>,
    /// Number of image tokens for each image.
    pub num_img_tokens: Vec<usize>,
    /// One size pair per image.
    /// NOTE(review): presumably the original `(width, height)` — confirm.
    pub image_sizes: Vec<(u32, u32)>,
    /// Extra model-specific outputs keyed by name (e.g. `"image_grid_thw"`).
    pub model_specific: HashMap<String, ModelSpecificValue>,
}

impl PreprocessedImages {
    /// Wraps a 4-D pixel tensor, starting with no model-specific extras.
    pub fn new(
        pixel_values: Array4<f32>,
        num_img_tokens: Vec<usize>,
        image_sizes: Vec<(u32, u32)>,
    ) -> Self {
        Self::new_dynamic(pixel_values.into_dyn(), num_img_tokens, image_sizes)
    }

    /// Wraps a pixel tensor of any rank, starting with no model-specific extras.
    pub fn new_dynamic(
        pixel_values: ArrayD<f32>,
        num_img_tokens: Vec<usize>,
        image_sizes: Vec<(u32, u32)>,
    ) -> Self {
        let model_specific = HashMap::new();
        Self {
            pixel_values,
            num_img_tokens,
            image_sizes,
            model_specific,
        }
    }

    /// Builder-style insertion of a model-specific value.
    ///
    /// An existing entry under the same key is replaced.
    pub fn with_extra(mut self, key: impl Into<String>, value: ModelSpecificValue) -> Self {
        let key: String = key.into();
        self.model_specific.insert(key, value);
        self
    }

    /// Length of the first (batch) axis of `pixel_values`.
    ///
    /// # Panics
    /// Panics if `pixel_values` is zero-dimensional.
    pub fn batch_size(&self) -> usize {
        let shape = self.pixel_values.shape();
        shape[0]
    }

    /// Channel count: axis 1 of a 4-D tensor, axis 2 of a 5-D tensor.
    ///
    /// # Errors
    /// Returns [`TransformError::InvalidShape`] for any other rank.
    pub fn channels(&self) -> Result<usize, TransformError> {
        dim_for_ndim(self.ndim(), 1, 2, self.pixel_values.shape())
    }

    /// Height: axis 2 of a 4-D tensor, axis 3 of a 5-D tensor.
    ///
    /// # Errors
    /// Returns [`TransformError::InvalidShape`] for any other rank.
    pub fn height(&self) -> Result<usize, TransformError> {
        dim_for_ndim(self.ndim(), 2, 3, self.pixel_values.shape())
    }

    /// Width: axis 3 of a 4-D tensor, axis 4 of a 5-D tensor.
    ///
    /// # Errors
    /// Returns [`TransformError::InvalidShape`] for any other rank.
    pub fn width(&self) -> Result<usize, TransformError> {
        dim_for_ndim(self.ndim(), 3, 4, self.pixel_values.shape())
    }

    /// Rank of `pixel_values`.
    pub fn ndim(&self) -> usize {
        self.pixel_values.ndim()
    }

    /// Sum of `num_img_tokens` across all images.
    pub fn total_tokens(&self) -> usize {
        self.num_img_tokens.iter().copied().sum()
    }

    /// Pixel data as one flat slice in logical (row-major iteration) order.
    ///
    /// Borrows when the array's memory is already a standard contiguous slice.
    /// Otherwise it copies the elements into a new `Vec`.
    pub fn pixel_values_flat(&self) -> Cow<'_, [f32]> {
        self.pixel_values.as_slice().map_or_else(
            || Cow::Owned(self.pixel_values.iter().copied().collect()),
            Cow::Borrowed,
        )
    }

    /// Owned copy of the `pixel_values` shape.
    pub fn pixel_values_shape(&self) -> Vec<usize> {
        self.pixel_values.shape().to_owned()
    }
}
/// Model-family-specific image preprocessing.
///
/// Implementations are stored as `Box<dyn ImagePreProcessor>` in
/// `ImageProcessorRegistry`. The `Send + Sync` supertraits make those boxed
/// processors safe to share across threads.
pub trait ImagePreProcessor: Send + Sync {
    /// Default per-channel mean (three entries).
    /// NOTE(review): presumably used for pixel normalization, in RGB order — confirm.
    fn default_mean(&self) -> [f64; 3];
    /// Default per-channel standard deviation (three entries), paired with
    /// [`default_mean`](Self::default_mean).
    fn default_std(&self) -> [f64; 3];
    /// Converts a batch of images into a [`PreprocessedImages`] using `config`.
    ///
    /// # Errors
    /// Returns a [`TransformError`] when preprocessing fails.
    fn preprocess(
        &self,
        images: &[DynamicImage],
        config: &PreProcessorConfig,
    ) -> Result<PreprocessedImages, TransformError>;
    /// Number of image tokens for an image of the given `width` and `height`.
    /// NOTE(review): presumably agrees with the per-image entries in
    /// `PreprocessedImages::num_img_tokens` — confirm per implementation.
    fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize;
    /// Short identifier of the model family (e.g. `"llava"`).
    fn model_name(&self) -> &'static str;
    /// Size images are processed at, if known.
    ///
    /// The default implementation returns `config.get_target_size()`.
    /// NOTE(review): element order of the pair (width/height) is defined by
    /// `PreProcessorConfig::get_target_size` — confirm there.
    fn get_processed_size(&self, config: &PreProcessorConfig) -> Option<(u32, u32)> {
        config.get_target_size()
    }
}
/// Maps model-id substrings ("patterns") to the [`ImagePreProcessor`] handling them.
///
/// Lookups are case-insensitive substring matches; see [`find`](Self::find) for
/// how overlapping patterns are resolved.
pub struct ImageProcessorRegistry {
    processors: HashMap<String, Box<dyn ImagePreProcessor>>,
}

impl ImageProcessorRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            processors: HashMap::new(),
        }
    }

    /// Registers `processor` for model ids that contain `pattern` (case-insensitive).
    ///
    /// Registering the same pattern string again replaces the earlier processor.
    pub fn register(&mut self, pattern: impl Into<String>, processor: Box<dyn ImagePreProcessor>) {
        self.processors.insert(pattern.into(), processor);
    }

    /// Returns the processor whose pattern occurs in `model_id`, ignoring case.
    ///
    /// Several patterns can match one id. For example, both `"llava"` and
    /// `"llava-v1.6"` occur in `"llava-hf/llava-v1.6-mistral-7b-hf"`. In that case
    /// the longest pattern wins, so the most specific registration is used. Among
    /// equally long matches, the lexicographically smallest pattern wins. The
    /// result is therefore deterministic and does not depend on `HashMap`
    /// iteration order.
    pub fn find(&self, model_id: &str) -> Option<&dyn ImagePreProcessor> {
        let model_lower = model_id.to_lowercase();
        self.processors
            .iter()
            .filter_map(|(pattern, processor)| {
                let pattern_lower = pattern.to_lowercase();
                if model_lower.contains(pattern_lower.as_str()) {
                    Some((pattern_lower.len(), pattern, processor))
                } else {
                    None
                }
            })
            // Longer match ranks higher; on equal length the smaller pattern ranks higher.
            .max_by(|a, b| a.0.cmp(&b.0).then_with(|| b.1.cmp(a.1)))
            .map(|(_, _, processor)| &**processor)
    }

    /// Returns `true` if [`find`](Self::find) would return a processor for `model_id`.
    pub fn has_processor(&self, model_id: &str) -> bool {
        self.find(model_id).is_some()
    }

    /// All registered patterns, sorted so the output order is stable.
    pub fn supported_patterns(&self) -> Vec<&str> {
        let mut patterns: Vec<&str> = self.processors.keys().map(String::as_str).collect();
        patterns.sort_unstable();
        patterns
    }
}

impl Default for ImageProcessorRegistry {
    /// Equivalent to [`ImageProcessorRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
impl ImageProcessorRegistry {
    /// Builds a registry pre-populated with the built-in model-family processors.
    ///
    /// Each family is registered under several spellings of its name (e.g.
    /// `"qwen2-vl"`, `"qwen2_vl"`). Every pattern gets its own freshly
    /// constructed processor instance.
    pub fn with_defaults() -> Self {
        use super::processors::{
            Llama4VisionProcessor, LlavaNextProcessor, LlavaProcessor, Phi3VisionProcessor,
            Qwen2VLProcessor, Qwen3VLProcessor,
        };

        let mut registry = Self::new();

        for pattern in ["llava-next", "llava-v1.6"] {
            registry.register(pattern, Box::new(LlavaNextProcessor::new()));
        }
        registry.register("llava", Box::new(LlavaProcessor::new()));

        for pattern in ["qwen3-vl", "qwen3_vl"] {
            registry.register(pattern, Box::new(Qwen3VLProcessor::new()));
        }
        // Qwen2.5-VL shares the Qwen2-VL processor.
        for pattern in [
            "qwen2-vl",
            "qwen2_vl",
            "qwen2.5-vl",
            "qwen2_5-vl",
            "qwen2_5_vl",
        ] {
            registry.register(pattern, Box::new(Qwen2VLProcessor::new()));
        }

        for pattern in ["phi-3-vision", "phi3-vision"] {
            registry.register(pattern, Box::new(Phi3VisionProcessor::new()));
        }
        for pattern in ["llama-4", "llama4"] {
            registry.register(pattern, Box::new(Llama4VisionProcessor::new()));
        }

        registry
    }
}
#[cfg(test)]
mod tests {
    use ndarray::Array4;
    use super::*;
    use crate::vision::processors::LlavaProcessor;

    /// Unwraps a `UintTensor` into `(data, shape)`, failing the test otherwise.
    fn unwrap_uint_tensor(value: ModelSpecificValue) -> (Vec<u32>, Vec<usize>) {
        if let ModelSpecificValue::UintTensor { data, shape } = value {
            (data, shape)
        } else {
            panic!("Expected UintTensor")
        }
    }

    /// Unwraps an `IntTensor` into `(data, shape)`, failing the test otherwise.
    fn unwrap_int_tensor(value: ModelSpecificValue) -> (Vec<i64>, Vec<usize>) {
        if let ModelSpecificValue::IntTensor { data, shape } = value {
            (data, shape)
        } else {
            panic!("Expected IntTensor")
        }
    }

    #[test]
    fn test_preprocessed_images_accessors() {
        let images = PreprocessedImages::new(
            Array4::<f32>::zeros((2, 3, 336, 336)),
            vec![576, 576],
            vec![(640, 480), (800, 600)],
        );
        assert_eq!(images.batch_size(), 2);
        assert_eq!(images.channels().unwrap(), 3);
        let (height, width) = (images.height().unwrap(), images.width().unwrap());
        assert_eq!(height, 336);
        assert_eq!(width, 336);
        assert_eq!(images.total_tokens(), 1152);
    }

    #[test]
    fn test_preprocessed_images_with_extra() {
        let images = PreprocessedImages::new(
            Array4::<f32>::zeros((1, 3, 224, 224)),
            vec![196],
            vec![(224, 224)],
        )
        .with_extra("image_grid_thw", ModelSpecificValue::uint_1d(vec![1, 16, 16]))
        .with_extra("aspect_ratio_id", ModelSpecificValue::Int(0));
        for key in ["image_grid_thw", "aspect_ratio_id"] {
            assert!(images.model_specific.contains_key(key), "missing key {}", key);
        }
    }

    #[test]
    fn test_model_specific_value_constructors() {
        let (data, shape) = unwrap_uint_tensor(ModelSpecificValue::uint_1d(vec![1, 2, 3]));
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(shape, vec![3]);

        let (data, shape) = unwrap_uint_tensor(ModelSpecificValue::uint_2d(vec![1, 2, 3, 4], 2, 2));
        assert_eq!(data, vec![1, 2, 3, 4]);
        assert_eq!(shape, vec![2, 2]);

        let (data, shape) = unwrap_int_tensor(ModelSpecificValue::int_1d(vec![1, 2, 3]));
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(shape, vec![3]);

        let (data, shape) = unwrap_int_tensor(ModelSpecificValue::int_2d(vec![1, 2, 3, 4], 2, 2));
        assert_eq!(data, vec![1, 2, 3, 4]);
        assert_eq!(shape, vec![2, 2]);
    }

    #[test]
    fn test_pixel_values_flat() {
        let pixel_values = Array4::from_shape_vec((1, 1, 2, 2), vec![1.0f32, 2.0, 3.0, 4.0])
            .expect("four values fill a 1x1x2x2 array");
        let images = PreprocessedImages::new(pixel_values, vec![4], vec![(2, 2)]);
        let flat = images.pixel_values_flat();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = ImageProcessorRegistry::with_defaults();
        for model_id in [
            "llava-hf/llava-1.5-7b-hf",
            "liuhaotian/llava-v1.5-7b",
            "llava-hf/llava-v1.6-mistral-7b-hf",
            "lmms-lab/llava-next-interleave-qwen-7b",
        ] {
            assert!(registry.has_processor(model_id), "no processor for {}", model_id);
        }
        let llava = registry
            .find("llava-hf/llava-1.5-7b-hf")
            .expect("llava-1.5 should resolve to a processor");
        assert_eq!(llava.model_name(), "llava");
    }

    #[test]
    fn test_registry_find() {
        let mut registry = ImageProcessorRegistry::new();
        registry.register("test-model", Box::new(LlavaProcessor::new()));
        assert!(registry.has_processor("test-model-7b"));
        assert!(registry.has_processor("TEST-MODEL"));
        assert!(!registry.has_processor("other-model"));
    }
}