Skip to main content

llm_multimodal/vision/
image_processor.rs

//! Image processor trait and output types.
//!
//! This module defines the interface for model-specific image processors
//! and the common output format for preprocessed images.
5
6use std::{borrow::Cow, collections::HashMap};
7
8use image::DynamicImage;
9use ndarray::{Array4, ArrayD};
10
11use super::{preprocessor_config::PreProcessorConfig, transforms::TransformError};
12use crate::types::FieldLayout;
13
14/// Helper to extract a dimension from pixel_values given an ndim-dependent axis index.
15/// Returns `Err` if the ndim is not 4 or 5.
16fn dim_for_ndim(
17    ndim: usize,
18    axis_4d: usize,
19    axis_5d: usize,
20    shape: &[usize],
21) -> Result<usize, TransformError> {
22    match ndim {
23        4 => Ok(shape[axis_4d]),
24        5 => Ok(shape[axis_5d]),
25        _ => Err(TransformError::InvalidShape {
26            expected: format!("4D or 5D pixel_values tensor, got {ndim}D"),
27            actual: shape.to_vec(),
28        }),
29    }
30}
31
/// Model-specific output values that vary by architecture.
///
/// Different vision models require different auxiliary outputs beyond
/// `pixel_values`; this enum captures the common forms those outputs take.
/// Tensor variants store their data as a flat vector next to an explicit shape.
///
/// `PartialEq` is derived so values can be compared directly (e.g. in tests)
/// instead of being destructured with `match`.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelSpecificValue {
    /// A tensor with shape information (data as flat vec, shape as dims)
    Tensor { data: Vec<f32>, shape: Vec<usize> },

    /// A tensor of integers (e.g., aspect_ratio_ids)
    IntTensor { data: Vec<i64>, shape: Vec<usize> },

    /// A tensor of unsigned integers (e.g., image_grid_thw)
    UintTensor { data: Vec<u32>, shape: Vec<usize> },

    /// Simple integer value
    Int(i64),

    /// Simple float value
    Float(f64),

    /// List of integers
    IntVec(Vec<i64>),

    /// List of unsigned integers
    UintVec(Vec<u32>),

    /// List of floats
    FloatVec(Vec<f32>),

    /// List of tuples (e.g., image sizes)
    TupleVec(Vec<(u32, u32)>),

    /// Boolean flag
    Bool(bool),
}

impl ModelSpecificValue {
    /// Create a 1D uint tensor whose shape is `[data.len()]`.
    pub fn uint_1d(data: Vec<u32>) -> Self {
        let len = data.len();
        Self::UintTensor {
            data,
            shape: vec![len],
        }
    }

    /// Create a 2D uint tensor with shape `[rows, cols]`.
    ///
    /// `data` is stored as given; callers are responsible for supplying
    /// `rows * cols` elements.
    pub fn uint_2d(data: Vec<u32>, rows: usize, cols: usize) -> Self {
        Self::UintTensor {
            data,
            shape: vec![rows, cols],
        }
    }

    /// Create a 1D int tensor whose shape is `[data.len()]`.
    pub fn int_1d(data: Vec<i64>) -> Self {
        let len = data.len();
        Self::IntTensor {
            data,
            shape: vec![len],
        }
    }

    /// Create a 2D int tensor with shape `[rows, cols]`.
    ///
    /// `data` is stored as given; callers are responsible for supplying
    /// `rows * cols` elements.
    pub fn int_2d(data: Vec<i64>, rows: usize, cols: usize) -> Self {
        Self::IntTensor {
            data,
            shape: vec![rows, cols],
        }
    }

    /// Get the first dimension of a tensor variant, if applicable.
    ///
    /// Returns `None` for non-tensor variants and for tensors whose shape is
    /// empty.
    pub fn first_dim(&self) -> Option<usize> {
        match self {
            Self::Tensor { shape, .. }
            | Self::IntTensor { shape, .. }
            | Self::UintTensor { shape, .. } => shape.first().copied(),
            // Listed explicitly (no `_` arm) so that adding a new variant,
            // possibly a new tensor kind, forces this match to be revisited.
            Self::Int(_)
            | Self::Float(_)
            | Self::IntVec(_)
            | Self::UintVec(_)
            | Self::FloatVec(_)
            | Self::TupleVec(_)
            | Self::Bool(_) => None,
        }
    }
}
114
115/// Preprocessed images ready for model consumption.
116///
117/// This struct contains all the outputs needed by the SGLang scheduler
118/// to construct `MultimodalInputs` for the model.
119#[derive(Debug, Clone)]
120pub struct PreprocessedImages {
121    /// Pixel values as a dynamic-dimensional float32 tensor.
122    ///
123    /// This is the primary input to the vision encoder.
124    /// Shape varies by model:
125    /// - Standard: [B, C, H, W] (4D)
126    /// - Phi3-Vision: [B, num_crops+1, C, H, W] (5D)
127    pub pixel_values: ArrayD<f32>,
128
129    /// Number of image tokens per image in the batch.
130    ///
131    /// Used to expand placeholder tokens in the text input.
132    /// For example, LLaVA with 336x336 and patch_size=14 produces 576 tokens.
133    pub num_img_tokens: Vec<usize>,
134
135    /// Original image sizes as (width, height) before preprocessing.
136    ///
137    /// Some models need this for proper attention masking or position encoding.
138    pub image_sizes: Vec<(u32, u32)>,
139
140    /// Model-specific auxiliary outputs.
141    ///
142    /// Examples:
143    /// - Qwen-VL: `image_grid_thw` for rotary position encoding
144    /// - LLaMA-Vision: `aspect_ratio_ids`, `aspect_ratio_mask`
145    /// - Phi3-Vision: `num_img_tokens` per crop
146    pub model_specific: HashMap<String, ModelSpecificValue>,
147}
148
149impl PreprocessedImages {
150    /// Create a new PreprocessedImages with required fields (4D pixel values).
151    pub fn new(
152        pixel_values: Array4<f32>,
153        num_img_tokens: Vec<usize>,
154        image_sizes: Vec<(u32, u32)>,
155    ) -> Self {
156        Self {
157            pixel_values: pixel_values.into_dyn(),
158            num_img_tokens,
159            image_sizes,
160            model_specific: HashMap::new(),
161        }
162    }
163
164    /// Create a new PreprocessedImages with dynamic-dimensional pixel values.
165    ///
166    /// Use this for models like Phi3-Vision that have 5D tensors.
167    pub fn new_dynamic(
168        pixel_values: ArrayD<f32>,
169        num_img_tokens: Vec<usize>,
170        image_sizes: Vec<(u32, u32)>,
171    ) -> Self {
172        Self {
173            pixel_values,
174            num_img_tokens,
175            image_sizes,
176            model_specific: HashMap::new(),
177        }
178    }
179
180    /// Add a model-specific value.
181    pub fn with_extra(mut self, key: impl Into<String>, value: ModelSpecificValue) -> Self {
182        self.model_specific.insert(key.into(), value);
183        self
184    }
185
186    /// Get the batch size.
187    pub fn batch_size(&self) -> usize {
188        self.pixel_values.shape()[0]
189    }
190
191    /// Get the number of channels.
192    ///
193    /// For 4D tensors [B, C, H, W], returns shape[1].
194    /// For 5D tensors [B, N, C, H, W] (Phi3-Vision), returns shape[2].
195    ///
196    /// # Errors
197    /// Returns `TransformError::InvalidShape` if pixel_values is not 4D or 5D.
198    pub fn channels(&self) -> Result<usize, TransformError> {
199        dim_for_ndim(self.pixel_values.ndim(), 1, 2, self.pixel_values.shape())
200    }
201
202    /// Get the height of processed images.
203    ///
204    /// For 4D tensors [B, C, H, W], returns shape[2].
205    /// For 5D tensors [B, N, C, H, W] (Phi3-Vision), returns shape[3].
206    ///
207    /// # Errors
208    /// Returns `TransformError::InvalidShape` if pixel_values is not 4D or 5D.
209    pub fn height(&self) -> Result<usize, TransformError> {
210        dim_for_ndim(self.pixel_values.ndim(), 2, 3, self.pixel_values.shape())
211    }
212
213    /// Get the width of processed images.
214    ///
215    /// For 4D tensors [B, C, H, W], returns shape[3].
216    /// For 5D tensors [B, N, C, H, W] (Phi3-Vision), returns shape[4].
217    ///
218    /// # Errors
219    /// Returns `TransformError::InvalidShape` if pixel_values is not 4D or 5D.
220    pub fn width(&self) -> Result<usize, TransformError> {
221        dim_for_ndim(self.pixel_values.ndim(), 3, 4, self.pixel_values.shape())
222    }
223
224    /// Get the number of dimensions of pixel_values.
225    pub fn ndim(&self) -> usize {
226        self.pixel_values.ndim()
227    }
228
229    /// Get total number of image tokens across all images.
230    pub fn total_tokens(&self) -> usize {
231        self.num_img_tokens.iter().sum()
232    }
233
234    /// Get pixel values as a flat f32 slice without copying if possible.
235    pub fn pixel_values_flat(&self) -> Cow<'_, [f32]> {
236        match self.pixel_values.as_slice() {
237            Some(slice) => Cow::Borrowed(slice),
238            None => Cow::Owned(self.pixel_values.iter().copied().collect()),
239        }
240    }
241
242    /// Get the shape of pixel values as a vector.
243    pub fn pixel_values_shape(&self) -> Vec<usize> {
244        self.pixel_values.shape().to_vec()
245    }
246
247    /// Number of images in this batch.
248    pub fn num_images(&self) -> usize {
249        self.image_sizes.len()
250    }
251
252    /// Extract batched tensor keys from explicit field layout declarations.
253    pub fn batched_keys(layouts: &HashMap<String, FieldLayout>) -> Vec<String> {
254        layouts
255            .iter()
256            .filter(|(_, l)| matches!(l, FieldLayout::Batched))
257            .map(|(k, _)| k.clone())
258            .collect()
259    }
260
261    /// Extract flat-slicing tensor keys from explicit field layout declarations.
262    ///
263    /// Returns a map of tensor name → sizes tensor name.
264    pub fn flat_keys(layouts: &HashMap<String, FieldLayout>) -> HashMap<String, String> {
265        layouts
266            .iter()
267            .filter_map(|(k, l)| match l {
268                FieldLayout::Flat { sizes_key } => Some((k.clone(), sizes_key.clone())),
269                FieldLayout::Batched => None,
270            })
271            .collect()
272    }
273}
274
275/// Trait for model-specific image preprocessors.
276///
277/// Each vision model (LLaVA, Qwen-VL, Phi3-Vision, etc.) implements this trait
278/// to provide the correct preprocessing pipeline.
279pub trait ImagePreProcessor: Send + Sync {
280    /// Default normalization mean for this model family.
281    fn default_mean(&self) -> [f64; 3];
282
283    /// Default normalization std for this model family.
284    fn default_std(&self) -> [f64; 3];
285
286    /// Preprocess a batch of images.
287    ///
288    /// # Arguments
289    /// * `images` - Input images to preprocess
290    /// * `config` - Preprocessor configuration from HuggingFace
291    ///
292    /// # Returns
293    /// Preprocessed images ready for the model, or an error.
294    fn preprocess(
295        &self,
296        images: &[DynamicImage],
297        config: &PreProcessorConfig,
298    ) -> Result<PreprocessedImages, TransformError>;
299
300    /// Calculate the number of image tokens for a given image size.
301    ///
302    /// This is used to determine how many placeholder tokens to insert
303    /// in the text input before the image has been fully processed.
304    ///
305    /// # Arguments
306    /// * `width` - Image width after preprocessing
307    /// * `height` - Image height after preprocessing
308    /// * `config` - Preprocessor configuration
309    fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize;
310
311    /// Get the model family name for identification.
312    fn model_name(&self) -> &'static str;
313
314    /// Get the expected image size after preprocessing.
315    ///
316    /// Some models have fixed sizes, others are dynamic.
317    fn get_processed_size(&self, config: &PreProcessorConfig) -> Option<(u32, u32)> {
318        config.get_target_size()
319    }
320}
321
322/// Registry of available image processors.
323pub struct ImageProcessorRegistry {
324    processors: HashMap<String, Box<dyn ImagePreProcessor>>,
325}
326
327impl ImageProcessorRegistry {
328    /// Create a new empty registry.
329    pub fn new() -> Self {
330        Self {
331            processors: HashMap::new(),
332        }
333    }
334
335    /// Register a processor for a model pattern.
336    pub fn register(&mut self, pattern: impl Into<String>, processor: Box<dyn ImagePreProcessor>) {
337        self.processors.insert(pattern.into(), processor);
338    }
339
340    /// Find a processor for the given model ID, falling back to model_type.
341    ///
342    /// Matches by substring containment (case-insensitive).
343    pub fn find(&self, model_id: &str, model_type: Option<&str>) -> Option<&dyn ImagePreProcessor> {
344        self.find_in_candidate(model_id)
345            .or_else(|| model_type.and_then(|mt| self.find_in_candidate(mt)))
346    }
347
348    fn find_in_candidate(&self, candidate: &str) -> Option<&dyn ImagePreProcessor> {
349        let candidate = candidate.to_lowercase();
350        for (pattern, processor) in &self.processors {
351            if candidate.contains(&pattern.to_lowercase()) {
352                return Some(processor.as_ref());
353            }
354        }
355        None
356    }
357
358    /// Get list of supported model patterns.
359    pub fn supported_patterns(&self) -> Vec<&str> {
360        self.processors.keys().map(|s| s.as_str()).collect()
361    }
362}
363
364impl Default for ImageProcessorRegistry {
365    fn default() -> Self {
366        Self::new()
367    }
368}
369
370impl ImageProcessorRegistry {
371    /// Create a registry with all built-in processors registered.
372    ///
373    /// Currently registers:
374    /// - `llava-next` -> LlavaNextProcessor
375    /// - `llava-1.5` / `llava-v1.5` -> LlavaProcessor
376    /// - `qwen2-vl` -> Qwen2VLProcessor
377    /// - `qwen2.5-vl` -> Qwen2VLProcessor (same preprocessing as Qwen2-VL)
378    /// - `qwen3-vl` -> Qwen3VLProcessor (patch_size=16, [0.5,0.5,0.5] normalization)
379    /// - `qwen3.5` / `qwen3_5` -> Qwen3VLProcessor (Qwen3.5 reuses Qwen3-VL preprocessing)
380    /// - `phi-3-vision` -> Phi3VisionProcessor (HD transform with 336x336 tiles)
381    pub fn with_defaults() -> Self {
382        let mut registry = Self::new();
383
384        // LLaVA-NeXT (v1.6+, anyres multi-crop)
385        registry.register(
386            "llava-next",
387            Box::new(super::processors::LlavaNextProcessor::new()),
388        );
389        registry.register(
390            "llava_next",
391            Box::new(super::processors::LlavaNextProcessor::new()),
392        );
393        registry.register(
394            "llava-v1.6",
395            Box::new(super::processors::LlavaNextProcessor::new()),
396        );
397
398        // Standard LLaVA (v1.5, single-patch).
399        // Use specific patterns so they don't accidentally match LLaVA-NeXT
400        // model IDs like "llava-v1.6-*".
401        registry.register(
402            "llava-1.5",
403            Box::new(super::processors::LlavaProcessor::new()),
404        );
405        registry.register(
406            "llava-v1.5",
407            Box::new(super::processors::LlavaProcessor::new()),
408        );
409
410        // Register Qwen3-VL first (more specific pattern - must match before qwen2)
411        registry.register(
412            "qwen3-vl",
413            Box::new(super::processors::Qwen3VLProcessor::new()),
414        );
415        registry.register(
416            "qwen3_vl",
417            Box::new(super::processors::Qwen3VLProcessor::new()),
418        );
419
420        // Qwen3.5 family (and Qwen3.6: same arch) reuses Qwen3-VL preprocessing.
421        registry.register(
422            "qwen3.5",
423            Box::new(super::processors::Qwen3VLProcessor::new()),
424        );
425        registry.register(
426            "qwen3_5",
427            Box::new(super::processors::Qwen3VLProcessor::new()),
428        );
429        registry.register(
430            "qwen3.6",
431            Box::new(super::processors::Qwen3VLProcessor::new()),
432        );
433        registry.register(
434            "qwen3_6",
435            Box::new(super::processors::Qwen3VLProcessor::new()),
436        );
437
438        // Register Qwen2-VL (matches Qwen/Qwen2-VL-*, etc.)
439        registry.register(
440            "qwen2-vl",
441            Box::new(super::processors::Qwen2VLProcessor::new()),
442        );
443        registry.register(
444            "qwen2_vl",
445            Box::new(super::processors::Qwen2VLProcessor::new()),
446        );
447
448        // Register Qwen2.5-VL (uses identical preprocessing to Qwen2-VL)
449        registry.register(
450            "qwen2.5-vl",
451            Box::new(super::processors::Qwen2VLProcessor::new()),
452        );
453        registry.register(
454            "qwen2_5-vl",
455            Box::new(super::processors::Qwen2VLProcessor::new()),
456        );
457        registry.register(
458            "qwen2_5_vl",
459            Box::new(super::processors::Qwen2VLProcessor::new()),
460        );
461
462        // Register Phi3-Vision
463        registry.register(
464            "phi-3-vision",
465            Box::new(super::processors::Phi3VisionProcessor::new()),
466        );
467        registry.register(
468            "phi3-vision",
469            Box::new(super::processors::Phi3VisionProcessor::new()),
470        );
471        registry.register(
472            "phi3_v",
473            Box::new(super::processors::Phi3VisionProcessor::new()),
474        );
475
476        // Register LLaMA 4 Vision
477        registry.register(
478            "llama-4",
479            Box::new(super::processors::Llama4VisionProcessor::new()),
480        );
481        registry.register(
482            "llama4",
483            Box::new(super::processors::Llama4VisionProcessor::new()),
484        );
485
486        // Register Kimi-K2.5 Vision
487        registry.register(
488            "kimi-k2",
489            Box::new(super::processors::KimiK25Processor::new()),
490        );
491        registry.register(
492            "kimi_k2",
493            Box::new(super::processors::KimiK25Processor::new()),
494        );
495
496        registry
497    }
498}
499
#[cfg(test)]
mod tests {
    use ndarray::Array4;

    use super::*;
    use crate::vision::processors::LlavaProcessor;

    /// Unwrap a `UintTensor` into `(data, shape)`, panicking on any other variant.
    fn unpack_uint(value: ModelSpecificValue) -> (Vec<u32>, Vec<usize>) {
        match value {
            ModelSpecificValue::UintTensor { data, shape } => (data, shape),
            _ => panic!("Expected UintTensor"),
        }
    }

    /// Unwrap an `IntTensor` into `(data, shape)`, panicking on any other variant.
    fn unpack_int(value: ModelSpecificValue) -> (Vec<i64>, Vec<usize>) {
        match value {
            ModelSpecificValue::IntTensor { data, shape } => (data, shape),
            _ => panic!("Expected IntTensor"),
        }
    }

    #[test]
    fn test_preprocessed_images_accessors() {
        let images = PreprocessedImages::new(
            Array4::<f32>::zeros((2, 3, 336, 336)),
            vec![576, 576],
            vec![(640, 480), (800, 600)],
        );

        assert_eq!(images.batch_size(), 2);
        assert_eq!(images.channels().unwrap(), 3);
        assert_eq!(images.height().unwrap(), 336);
        assert_eq!(images.width().unwrap(), 336);
        assert_eq!(images.total_tokens(), 1152);
    }

    #[test]
    fn test_preprocessed_images_with_extra() {
        let base = PreprocessedImages::new(
            Array4::<f32>::zeros((1, 3, 224, 224)),
            vec![196],
            vec![(224, 224)],
        );
        let images = base
            .with_extra(
                "image_grid_thw",
                ModelSpecificValue::uint_1d(vec![1, 16, 16]),
            )
            .with_extra("aspect_ratio_id", ModelSpecificValue::Int(0));

        for key in ["image_grid_thw", "aspect_ratio_id"] {
            assert!(images.model_specific.contains_key(key), "missing {key}");
        }
    }

    #[test]
    fn test_model_specific_value_constructors() {
        let (data, shape) = unpack_uint(ModelSpecificValue::uint_1d(vec![1, 2, 3]));
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(shape, vec![3]);

        let (data, shape) = unpack_uint(ModelSpecificValue::uint_2d(vec![1, 2, 3, 4], 2, 2));
        assert_eq!(data, vec![1, 2, 3, 4]);
        assert_eq!(shape, vec![2, 2]);

        let (data, shape) = unpack_int(ModelSpecificValue::int_1d(vec![1, 2, 3]));
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(shape, vec![3]);

        let (data, shape) = unpack_int(ModelSpecificValue::int_2d(vec![1, 2, 3, 4], 2, 2));
        assert_eq!(data, vec![1, 2, 3, 4]);
        assert_eq!(shape, vec![2, 2]);
    }

    #[test]
    fn test_pixel_values_flat() {
        let mut pixel_values = Array4::<f32>::zeros((1, 1, 2, 2));
        // Fill the 2x2 plane row by row with 1.0, 2.0, 3.0, 4.0.
        for (i, &value) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
            pixel_values[[0, 0, i / 2, i % 2]] = value;
        }

        let images = PreprocessedImages::new(pixel_values, vec![4], vec![(2, 2)]);
        let flat = images.pixel_values_flat();

        assert_eq!(&flat[..], &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = ImageProcessorRegistry::with_defaults();

        // LLaVA 1.5 and LLaVA-NeXT model IDs must all resolve.
        for model_id in [
            "llava-hf/llava-1.5-7b-hf",
            "liuhaotian/llava-v1.5-7b",
            "llava-hf/llava-v1.6-mistral-7b-hf",
            "lmms-lab/llava-next-interleave-qwen-7b",
        ] {
            assert!(
                registry.find(model_id, None).is_some(),
                "no processor for {model_id}"
            );
        }

        let name = registry
            .find("llava-hf/llava-1.5-7b-hf", None)
            .map(|p| p.model_name());
        assert_eq!(name, Some("llava"));
    }

    #[test]
    fn test_registry_find() {
        let mut registry = ImageProcessorRegistry::new();

        // Any concrete processor works as a stand-in here.
        registry.register("test-model", Box::new(LlavaProcessor::new()));

        let found = |id: &str| registry.find(id, None).is_some();
        assert!(found("test-model-7b"));
        assert!(found("TEST-MODEL"));
        assert!(!found("other-model"));
    }

    #[test]
    fn test_registry_find_falls_back_to_model_type() {
        let registry = ImageProcessorRegistry::with_defaults();

        assert!(registry.find("custom-model", None).is_none());

        let name = registry
            .find("custom-model", Some("qwen3_vl"))
            .map(|p| p.model_name());
        assert_eq!(name, Some("qwen3-vl"), "qwen3 processor by model_type");
    }

    #[test]
    fn test_registry_find_preserves_fast_path() {
        let registry = ImageProcessorRegistry::with_defaults();

        // The model ID match must win over the (different) model_type hint.
        let name = registry
            .find("Qwen3-VL-30B-A3B-Instruct", Some("qwen2_vl"))
            .map(|p| p.model_name());
        assert_eq!(name, Some("qwen3-vl"), "qwen3 processor by model_id");
    }

    #[test]
    fn test_registry_find_phi3_model_type_fallback() {
        let registry = ImageProcessorRegistry::with_defaults();

        let name = registry
            .find("custom-model", Some("phi3_v"))
            .map(|p| p.model_name());
        assert_eq!(name, Some("phi3-vision"), "phi3 processor by model_type");
    }
}