Skip to main content

llm_multimodal/vision/
image_processor.rs

1//! Image processor trait and output types.
2//!
3//! This module defines the interface for model-specific image processors
4//! and the common output format for preprocessed images.
5
6use std::{borrow::Cow, collections::HashMap};
7
8use image::DynamicImage;
9use ndarray::{Array4, ArrayD};
10
11use super::{preprocessor_config::PreProcessorConfig, transforms::TransformError};
12use crate::types::FieldLayout;
13
14/// Helper to extract a dimension from pixel_values given an ndim-dependent axis index.
15/// Returns `Err` if the ndim is not 4 or 5.
16fn dim_for_ndim(
17    ndim: usize,
18    axis_4d: usize,
19    axis_5d: usize,
20    shape: &[usize],
21) -> Result<usize, TransformError> {
22    match ndim {
23        4 => Ok(shape[axis_4d]),
24        5 => Ok(shape[axis_5d]),
25        _ => Err(TransformError::InvalidShape {
26            expected: format!("4D or 5D pixel_values tensor, got {ndim}D"),
27            actual: shape.to_vec(),
28        }),
29    }
30}
31
/// Auxiliary, architecture-dependent values that accompany `pixel_values`.
///
/// Vision models differ in which extra outputs they need; this enum covers the
/// shapes those outputs commonly take (tensors, scalars, lists, flags).
#[derive(Debug, Clone)]
pub enum ModelSpecificValue {
    /// Float tensor stored as a flat buffer plus its dimensions.
    Tensor { data: Vec<f32>, shape: Vec<usize> },

    /// Signed integer tensor (e.g., aspect_ratio_ids).
    IntTensor { data: Vec<i64>, shape: Vec<usize> },

    /// Unsigned integer tensor (e.g., image_grid_thw).
    UintTensor { data: Vec<u32>, shape: Vec<usize> },

    /// Single integer.
    Int(i64),

    /// Single float.
    Float(f64),

    /// Sequence of signed integers.
    IntVec(Vec<i64>),

    /// Sequence of unsigned integers.
    UintVec(Vec<u32>),

    /// Sequence of floats.
    FloatVec(Vec<f32>),

    /// Sequence of pairs (e.g., image sizes).
    TupleVec(Vec<(u32, u32)>),

    /// Boolean flag.
    Bool(bool),
}

impl ModelSpecificValue {
    /// Wrap `data` as a 1D unsigned tensor whose single dimension is `data.len()`.
    pub fn uint_1d(data: Vec<u32>) -> Self {
        // `shape` is listed first so the length is read before `data` is moved.
        Self::UintTensor {
            shape: vec![data.len()],
            data,
        }
    }

    /// Wrap `data` as an unsigned tensor with shape `[rows, cols]`.
    ///
    /// The length of `data` is not checked against `rows * cols`.
    pub fn uint_2d(data: Vec<u32>, rows: usize, cols: usize) -> Self {
        let shape = vec![rows, cols];
        Self::UintTensor { data, shape }
    }

    /// Wrap `data` as a 1D signed tensor whose single dimension is `data.len()`.
    pub fn int_1d(data: Vec<i64>) -> Self {
        // `shape` is listed first so the length is read before `data` is moved.
        Self::IntTensor {
            shape: vec![data.len()],
            data,
        }
    }

    /// Wrap `data` as a signed tensor with shape `[rows, cols]`.
    ///
    /// The length of `data` is not checked against `rows * cols`.
    pub fn int_2d(data: Vec<i64>, rows: usize, cols: usize) -> Self {
        let shape = vec![rows, cols];
        Self::IntTensor { data, shape }
    }

    /// Leading dimension of a tensor variant.
    ///
    /// Returns `None` for non-tensor variants and for tensors with an empty shape.
    pub fn first_dim(&self) -> Option<usize> {
        let dims = match self {
            Self::Tensor { shape, .. }
            | Self::IntTensor { shape, .. }
            | Self::UintTensor { shape, .. } => shape,
            Self::Int(_)
            | Self::Float(_)
            | Self::IntVec(_)
            | Self::UintVec(_)
            | Self::FloatVec(_)
            | Self::TupleVec(_)
            | Self::Bool(_) => return None,
        };
        dims.first().copied()
    }
}
114
115/// Preprocessed images ready for model consumption.
116///
117/// This struct contains all the outputs needed by the SGLang scheduler
118/// to construct `MultimodalInputs` for the model.
119#[derive(Debug, Clone)]
120pub struct PreprocessedImages {
121    /// Pixel values as a dynamic-dimensional float32 tensor.
122    ///
123    /// This is the primary input to the vision encoder.
124    /// Shape varies by model:
125    /// - Standard: [B, C, H, W] (4D)
126    /// - Phi3-Vision: [B, num_crops+1, C, H, W] (5D)
127    pub pixel_values: ArrayD<f32>,
128
129    /// Number of image tokens per image in the batch.
130    ///
131    /// Used to expand placeholder tokens in the text input.
132    /// For example, LLaVA with 336x336 and patch_size=14 produces 576 tokens.
133    pub num_img_tokens: Vec<usize>,
134
135    /// Original image sizes as (width, height) before preprocessing.
136    ///
137    /// Some models need this for proper attention masking or position encoding.
138    pub image_sizes: Vec<(u32, u32)>,
139
140    /// Model-specific auxiliary outputs.
141    ///
142    /// Examples:
143    /// - Qwen-VL: `image_grid_thw` for rotary position encoding
144    /// - LLaMA-Vision: `aspect_ratio_ids`, `aspect_ratio_mask`
145    /// - Phi3-Vision: `num_img_tokens` per crop
146    pub model_specific: HashMap<String, ModelSpecificValue>,
147}
148
149impl PreprocessedImages {
150    /// Create a new PreprocessedImages with required fields (4D pixel values).
151    pub fn new(
152        pixel_values: Array4<f32>,
153        num_img_tokens: Vec<usize>,
154        image_sizes: Vec<(u32, u32)>,
155    ) -> Self {
156        Self {
157            pixel_values: pixel_values.into_dyn(),
158            num_img_tokens,
159            image_sizes,
160            model_specific: HashMap::new(),
161        }
162    }
163
164    /// Create a new PreprocessedImages with dynamic-dimensional pixel values.
165    ///
166    /// Use this for models like Phi3-Vision that have 5D tensors.
167    pub fn new_dynamic(
168        pixel_values: ArrayD<f32>,
169        num_img_tokens: Vec<usize>,
170        image_sizes: Vec<(u32, u32)>,
171    ) -> Self {
172        Self {
173            pixel_values,
174            num_img_tokens,
175            image_sizes,
176            model_specific: HashMap::new(),
177        }
178    }
179
180    /// Add a model-specific value.
181    pub fn with_extra(mut self, key: impl Into<String>, value: ModelSpecificValue) -> Self {
182        self.model_specific.insert(key.into(), value);
183        self
184    }
185
186    /// Get the batch size.
187    pub fn batch_size(&self) -> usize {
188        self.pixel_values.shape()[0]
189    }
190
191    /// Get the number of channels.
192    ///
193    /// For 4D tensors [B, C, H, W], returns shape[1].
194    /// For 5D tensors [B, N, C, H, W] (Phi3-Vision), returns shape[2].
195    ///
196    /// # Errors
197    /// Returns `TransformError::InvalidShape` if pixel_values is not 4D or 5D.
198    pub fn channels(&self) -> Result<usize, TransformError> {
199        dim_for_ndim(self.pixel_values.ndim(), 1, 2, self.pixel_values.shape())
200    }
201
202    /// Get the height of processed images.
203    ///
204    /// For 4D tensors [B, C, H, W], returns shape[2].
205    /// For 5D tensors [B, N, C, H, W] (Phi3-Vision), returns shape[3].
206    ///
207    /// # Errors
208    /// Returns `TransformError::InvalidShape` if pixel_values is not 4D or 5D.
209    pub fn height(&self) -> Result<usize, TransformError> {
210        dim_for_ndim(self.pixel_values.ndim(), 2, 3, self.pixel_values.shape())
211    }
212
213    /// Get the width of processed images.
214    ///
215    /// For 4D tensors [B, C, H, W], returns shape[3].
216    /// For 5D tensors [B, N, C, H, W] (Phi3-Vision), returns shape[4].
217    ///
218    /// # Errors
219    /// Returns `TransformError::InvalidShape` if pixel_values is not 4D or 5D.
220    pub fn width(&self) -> Result<usize, TransformError> {
221        dim_for_ndim(self.pixel_values.ndim(), 3, 4, self.pixel_values.shape())
222    }
223
224    /// Get the number of dimensions of pixel_values.
225    pub fn ndim(&self) -> usize {
226        self.pixel_values.ndim()
227    }
228
229    /// Get total number of image tokens across all images.
230    pub fn total_tokens(&self) -> usize {
231        self.num_img_tokens.iter().sum()
232    }
233
234    /// Get pixel values as a flat f32 slice without copying if possible.
235    pub fn pixel_values_flat(&self) -> Cow<'_, [f32]> {
236        match self.pixel_values.as_slice() {
237            Some(slice) => Cow::Borrowed(slice),
238            None => Cow::Owned(self.pixel_values.iter().copied().collect()),
239        }
240    }
241
242    /// Get the shape of pixel values as a vector.
243    pub fn pixel_values_shape(&self) -> Vec<usize> {
244        self.pixel_values.shape().to_vec()
245    }
246
247    /// Number of images in this batch.
248    pub fn num_images(&self) -> usize {
249        self.image_sizes.len()
250    }
251
252    /// Extract batched tensor keys from explicit field layout declarations.
253    pub fn batched_keys(layouts: &HashMap<String, FieldLayout>) -> Vec<String> {
254        layouts
255            .iter()
256            .filter(|(_, l)| matches!(l, FieldLayout::Batched))
257            .map(|(k, _)| k.clone())
258            .collect()
259    }
260
261    /// Extract flat-slicing tensor keys from explicit field layout declarations.
262    ///
263    /// Returns a map of tensor name → sizes tensor name.
264    pub fn flat_keys(layouts: &HashMap<String, FieldLayout>) -> HashMap<String, String> {
265        layouts
266            .iter()
267            .filter_map(|(k, l)| match l {
268                FieldLayout::Flat { sizes_key } => Some((k.clone(), sizes_key.clone())),
269                FieldLayout::Batched => None,
270            })
271            .collect()
272    }
273}
274
275/// Trait for model-specific image preprocessors.
276///
277/// Each vision model (LLaVA, Qwen-VL, Phi3-Vision, etc.) implements this trait
278/// to provide the correct preprocessing pipeline.
279pub trait ImagePreProcessor: Send + Sync {
280    /// Default normalization mean for this model family.
281    fn default_mean(&self) -> [f64; 3];
282
283    /// Default normalization std for this model family.
284    fn default_std(&self) -> [f64; 3];
285
286    /// Preprocess a batch of images.
287    ///
288    /// # Arguments
289    /// * `images` - Input images to preprocess
290    /// * `config` - Preprocessor configuration from HuggingFace
291    ///
292    /// # Returns
293    /// Preprocessed images ready for the model, or an error.
294    fn preprocess(
295        &self,
296        images: &[DynamicImage],
297        config: &PreProcessorConfig,
298    ) -> Result<PreprocessedImages, TransformError>;
299
300    /// Calculate the number of image tokens for a given image size.
301    ///
302    /// This is used to determine how many placeholder tokens to insert
303    /// in the text input before the image has been fully processed.
304    ///
305    /// # Arguments
306    /// * `width` - Image width after preprocessing
307    /// * `height` - Image height after preprocessing
308    /// * `config` - Preprocessor configuration
309    fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize;
310
311    /// Get the model family name for identification.
312    fn model_name(&self) -> &'static str;
313
314    /// Get the expected image size after preprocessing.
315    ///
316    /// Some models have fixed sizes, others are dynamic.
317    fn get_processed_size(&self, config: &PreProcessorConfig) -> Option<(u32, u32)> {
318        config.get_target_size()
319    }
320}
321
322/// Registry of available image processors.
323pub struct ImageProcessorRegistry {
324    processors: HashMap<String, Box<dyn ImagePreProcessor>>,
325}
326
327impl ImageProcessorRegistry {
328    /// Create a new empty registry.
329    pub fn new() -> Self {
330        Self {
331            processors: HashMap::new(),
332        }
333    }
334
335    /// Register a processor for a model pattern.
336    pub fn register(&mut self, pattern: impl Into<String>, processor: Box<dyn ImagePreProcessor>) {
337        self.processors.insert(pattern.into(), processor);
338    }
339
340    /// Find a processor for the given model ID, falling back to model_type.
341    ///
342    /// Matches by substring containment (case-insensitive).
343    pub fn find(&self, model_id: &str, model_type: Option<&str>) -> Option<&dyn ImagePreProcessor> {
344        self.find_in_candidate(model_id)
345            .or_else(|| model_type.and_then(|mt| self.find_in_candidate(mt)))
346    }
347
348    fn find_in_candidate(&self, candidate: &str) -> Option<&dyn ImagePreProcessor> {
349        let candidate = candidate.to_lowercase();
350        for (pattern, processor) in &self.processors {
351            if candidate.contains(&pattern.to_lowercase()) {
352                return Some(processor.as_ref());
353            }
354        }
355        None
356    }
357
358    /// Get list of supported model patterns.
359    pub fn supported_patterns(&self) -> Vec<&str> {
360        self.processors.keys().map(|s| s.as_str()).collect()
361    }
362}
363
364impl Default for ImageProcessorRegistry {
365    fn default() -> Self {
366        Self::new()
367    }
368}
369
370impl ImageProcessorRegistry {
371    /// Create a registry with all built-in processors registered.
372    ///
373    /// Currently registers:
374    /// - `llava-next` -> LlavaNextProcessor
375    /// - `llava-1.5` / `llava-v1.5` -> LlavaProcessor
376    /// - `qwen2-vl` -> Qwen2VLProcessor
377    /// - `qwen2.5-vl` -> Qwen2VLProcessor (same preprocessing as Qwen2-VL)
378    /// - `qwen3-vl` -> Qwen3VLProcessor (patch_size=16, [0.5,0.5,0.5] normalization)
379    /// - `phi-3-vision` -> Phi3VisionProcessor (HD transform with 336x336 tiles)
380    pub fn with_defaults() -> Self {
381        let mut registry = Self::new();
382
383        // LLaVA-NeXT (v1.6+, anyres multi-crop)
384        registry.register(
385            "llava-next",
386            Box::new(super::processors::LlavaNextProcessor::new()),
387        );
388        registry.register(
389            "llava_next",
390            Box::new(super::processors::LlavaNextProcessor::new()),
391        );
392        registry.register(
393            "llava-v1.6",
394            Box::new(super::processors::LlavaNextProcessor::new()),
395        );
396
397        // Standard LLaVA (v1.5, single-patch).
398        // Use specific patterns so they don't accidentally match LLaVA-NeXT
399        // model IDs like "llava-v1.6-*".
400        registry.register(
401            "llava-1.5",
402            Box::new(super::processors::LlavaProcessor::new()),
403        );
404        registry.register(
405            "llava-v1.5",
406            Box::new(super::processors::LlavaProcessor::new()),
407        );
408
409        // Register Qwen3-VL first (more specific pattern - must match before qwen2)
410        registry.register(
411            "qwen3-vl",
412            Box::new(super::processors::Qwen3VLProcessor::new()),
413        );
414        registry.register(
415            "qwen3_vl",
416            Box::new(super::processors::Qwen3VLProcessor::new()),
417        );
418
419        // Register Qwen2-VL (matches Qwen/Qwen2-VL-*, etc.)
420        registry.register(
421            "qwen2-vl",
422            Box::new(super::processors::Qwen2VLProcessor::new()),
423        );
424        registry.register(
425            "qwen2_vl",
426            Box::new(super::processors::Qwen2VLProcessor::new()),
427        );
428
429        // Register Qwen2.5-VL (uses identical preprocessing to Qwen2-VL)
430        registry.register(
431            "qwen2.5-vl",
432            Box::new(super::processors::Qwen2VLProcessor::new()),
433        );
434        registry.register(
435            "qwen2_5-vl",
436            Box::new(super::processors::Qwen2VLProcessor::new()),
437        );
438        registry.register(
439            "qwen2_5_vl",
440            Box::new(super::processors::Qwen2VLProcessor::new()),
441        );
442
443        // Register Phi3-Vision
444        registry.register(
445            "phi-3-vision",
446            Box::new(super::processors::Phi3VisionProcessor::new()),
447        );
448        registry.register(
449            "phi3-vision",
450            Box::new(super::processors::Phi3VisionProcessor::new()),
451        );
452        registry.register(
453            "phi3_v",
454            Box::new(super::processors::Phi3VisionProcessor::new()),
455        );
456
457        // Register LLaMA 4 Vision
458        registry.register(
459            "llama-4",
460            Box::new(super::processors::Llama4VisionProcessor::new()),
461        );
462        registry.register(
463            "llama4",
464            Box::new(super::processors::Llama4VisionProcessor::new()),
465        );
466
467        registry
468    }
469}
470
#[cfg(test)]
mod tests {
    use ndarray::Array4;

    use super::*;
    use crate::vision::processors::LlavaProcessor;

    /// Unwrap a `UintTensor` into `(data, shape)`, failing the test otherwise.
    fn expect_uint_tensor(value: ModelSpecificValue) -> (Vec<u32>, Vec<usize>) {
        match value {
            ModelSpecificValue::UintTensor { data, shape } => (data, shape),
            _ => panic!("Expected UintTensor"),
        }
    }

    /// Unwrap an `IntTensor` into `(data, shape)`, failing the test otherwise.
    fn expect_int_tensor(value: ModelSpecificValue) -> (Vec<i64>, Vec<usize>) {
        match value {
            ModelSpecificValue::IntTensor { data, shape } => (data, shape),
            _ => panic!("Expected IntTensor"),
        }
    }

    /// Accessors on a plain 4D batch report the expected dimensions.
    #[test]
    fn test_preprocessed_images_accessors() {
        let batch = PreprocessedImages::new(
            Array4::<f32>::zeros((2, 3, 336, 336)),
            vec![576, 576],
            vec![(640, 480), (800, 600)],
        );

        assert_eq!(batch.batch_size(), 2);
        assert_eq!(batch.channels().unwrap(), 3);
        assert_eq!(batch.height().unwrap(), 336);
        assert_eq!(batch.width().unwrap(), 336);
        assert_eq!(batch.total_tokens(), 1152);
    }

    /// `with_extra` stores each value under the key it was given.
    #[test]
    fn test_preprocessed_images_with_extra() {
        let grid = ModelSpecificValue::uint_1d(vec![1, 16, 16]);
        let ratio_id = ModelSpecificValue::Int(0);

        let batch = PreprocessedImages::new(
            Array4::<f32>::zeros((1, 3, 224, 224)),
            vec![196],
            vec![(224, 224)],
        )
        .with_extra("image_grid_thw", grid)
        .with_extra("aspect_ratio_id", ratio_id);

        for key in ["image_grid_thw", "aspect_ratio_id"] {
            assert!(batch.model_specific.contains_key(key));
        }
    }

    /// The tensor constructors keep the data and derive the right shape.
    #[test]
    fn test_model_specific_value_constructors() {
        let (data, shape) = expect_uint_tensor(ModelSpecificValue::uint_1d(vec![1, 2, 3]));
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(shape, vec![3]);

        let (data, shape) =
            expect_uint_tensor(ModelSpecificValue::uint_2d(vec![1, 2, 3, 4], 2, 2));
        assert_eq!(data, vec![1, 2, 3, 4]);
        assert_eq!(shape, vec![2, 2]);

        let (data, shape) = expect_int_tensor(ModelSpecificValue::int_1d(vec![1, 2, 3]));
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(shape, vec![3]);

        let (data, shape) = expect_int_tensor(ModelSpecificValue::int_2d(vec![1, 2, 3, 4], 2, 2));
        assert_eq!(data, vec![1, 2, 3, 4]);
        assert_eq!(shape, vec![2, 2]);
    }

    /// Flattening yields the elements in row-major order.
    #[test]
    fn test_pixel_values_flat() {
        let mut pixel_values = Array4::<f32>::zeros((1, 1, 2, 2));
        for (offset, value) in [1.0_f32, 2.0, 3.0, 4.0].into_iter().enumerate() {
            pixel_values[[0, 0, offset / 2, offset % 2]] = value;
        }

        let batch = PreprocessedImages::new(pixel_values, vec![4], vec![(2, 2)]);

        assert_eq!(batch.pixel_values_flat().to_vec(), vec![1.0_f32, 2.0, 3.0, 4.0]);
    }

    /// The default registry resolves LLaVA and LLaVA-NeXT model IDs.
    #[test]
    fn test_registry_with_defaults() {
        let registry = ImageProcessorRegistry::with_defaults();

        let known_model_ids = [
            // Standard LLaVA
            "llava-hf/llava-1.5-7b-hf",
            "liuhaotian/llava-v1.5-7b",
            // LLaVA-NeXT
            "llava-hf/llava-v1.6-mistral-7b-hf",
            "lmms-lab/llava-next-interleave-qwen-7b",
        ];
        for model_id in known_model_ids {
            assert!(registry.find(model_id, None).is_some());
        }

        // The resolved processor reports the expected model family.
        let llava = registry.find("llava-hf/llava-1.5-7b-hf", None).unwrap();
        assert_eq!(llava.model_name(), "llava");
    }

    /// Lookup is a case-insensitive substring match on the model ID.
    #[test]
    fn test_registry_find() {
        let mut registry = ImageProcessorRegistry::new();

        // LlavaProcessor stands in as a mock processor here.
        registry.register("test-model", Box::new(LlavaProcessor::new()));

        assert!(registry.find("test-model-7b", None).is_some());
        assert!(registry.find("TEST-MODEL", None).is_some());
        assert!(registry.find("other-model", None).is_none());
    }

    /// When the model ID matches nothing, `model_type` is used instead.
    #[test]
    fn test_registry_find_falls_back_to_model_type() {
        let registry = ImageProcessorRegistry::with_defaults();

        assert!(registry.find("custom-model", None).is_none());

        let by_type = registry.find("custom-model", Some("qwen3_vl"));
        let processor = by_type.expect("qwen3 processor by model_type");
        assert_eq!(processor.model_name(), "qwen3-vl");
    }

    /// A matching model ID takes precedence over `model_type`.
    #[test]
    fn test_registry_find_preserves_fast_path() {
        let registry = ImageProcessorRegistry::with_defaults();

        let by_id = registry.find("Qwen3-VL-30B-A3B-Instruct", Some("qwen2_vl"));
        let processor = by_id.expect("qwen3 processor by model_id");
        assert_eq!(processor.model_name(), "qwen3-vl");
    }

    /// The Phi3 `model_type` resolves to the Phi3-Vision processor.
    #[test]
    fn test_registry_find_phi3_model_type_fallback() {
        let registry = ImageProcessorRegistry::with_defaults();

        let by_type = registry.find("custom-model", Some("phi3_v"));
        let processor = by_type.expect("phi3 processor by model_type");
        assert_eq!(processor.model_name(), "phi3-vision");
    }
}