1use std::{borrow::Cow, collections::HashMap};
7
8use image::DynamicImage;
9use ndarray::{Array4, ArrayD};
10
11use super::{preprocessor_config::PreProcessorConfig, transforms::TransformError};
12use crate::types::FieldLayout;
13
14fn dim_for_ndim(
17 ndim: usize,
18 axis_4d: usize,
19 axis_5d: usize,
20 shape: &[usize],
21) -> Result<usize, TransformError> {
22 match ndim {
23 4 => Ok(shape[axis_4d]),
24 5 => Ok(shape[axis_5d]),
25 _ => Err(TransformError::InvalidShape {
26 expected: format!("4D or 5D pixel_values tensor, got {ndim}D"),
27 actual: shape.to_vec(),
28 }),
29 }
30}
31
/// A dynamically typed extra value that can accompany preprocessed images.
///
/// The `*Tensor` variants carry a flat `data` buffer together with an explicit
/// `shape`; the remaining variants hold a scalar or a plain vector.
#[derive(Debug, Clone)]
pub enum ModelSpecificValue {
    /// `f32` data with an explicit shape.
    Tensor { data: Vec<f32>, shape: Vec<usize> },

    /// `i64` data with an explicit shape.
    IntTensor { data: Vec<i64>, shape: Vec<usize> },

    /// `u32` data with an explicit shape.
    UintTensor { data: Vec<u32>, shape: Vec<usize> },

    /// A single signed integer.
    Int(i64),

    /// A single floating-point number.
    Float(f64),

    /// A list of signed integers.
    IntVec(Vec<i64>),

    /// A list of unsigned integers.
    UintVec(Vec<u32>),

    /// A list of `f32` values.
    FloatVec(Vec<f32>),

    /// A list of `(u32, u32)` pairs.
    TupleVec(Vec<(u32, u32)>),

    /// A boolean flag.
    Bool(bool),
}

impl ModelSpecificValue {
    /// Wraps `data` as a [`Self::UintTensor`] whose shape is `[data.len()]`.
    pub fn uint_1d(data: Vec<u32>) -> Self {
        let shape = vec![data.len()];
        Self::UintTensor { data, shape }
    }

    /// Wraps `data` as a [`Self::UintTensor`] with shape `[rows, cols]`.
    ///
    /// The shape is recorded as given; it is not checked against `data.len()`.
    pub fn uint_2d(data: Vec<u32>, rows: usize, cols: usize) -> Self {
        let shape = vec![rows, cols];
        Self::UintTensor { data, shape }
    }

    /// Wraps `data` as a [`Self::IntTensor`] whose shape is `[data.len()]`.
    pub fn int_1d(data: Vec<i64>) -> Self {
        let shape = vec![data.len()];
        Self::IntTensor { data, shape }
    }

    /// Wraps `data` as a [`Self::IntTensor`] with shape `[rows, cols]`.
    ///
    /// The shape is recorded as given; it is not checked against `data.len()`.
    pub fn int_2d(data: Vec<i64>, rows: usize, cols: usize) -> Self {
        let shape = vec![rows, cols];
        Self::IntTensor { data, shape }
    }

    /// Returns the first entry of `shape` for the tensor variants.
    ///
    /// Yields `None` for every non-tensor variant and for a tensor whose
    /// `shape` is empty.
    pub fn first_dim(&self) -> Option<usize> {
        let shape = match self {
            Self::Tensor { shape, .. }
            | Self::IntTensor { shape, .. }
            | Self::UintTensor { shape, .. } => shape,
            Self::Int(_)
            | Self::Float(_)
            | Self::IntVec(_)
            | Self::UintVec(_)
            | Self::FloatVec(_)
            | Self::TupleVec(_)
            | Self::Bool(_) => return None,
        };
        shape.first().copied()
    }
}
114
/// Result of image preprocessing: a pixel tensor plus per-image metadata and
/// optional named extras.
#[derive(Debug, Clone)]
pub struct PreprocessedImages {
    /// Pixel data as a dynamic-rank `f32` array.
    ///
    /// The `channels`, `height` and `width` accessors only accept 4D or 5D
    /// arrays and return an error for any other rank.
    pub pixel_values: ArrayD<f32>,

    /// Token counts; `total_tokens` returns their sum.
    ///
    /// NOTE(review): presumably one entry per image — confirm against producers.
    pub num_img_tokens: Vec<usize>,

    /// One `(u32, u32)` size pair per image; `num_images` reports this
    /// vector's length.
    ///
    /// NOTE(review): the component order (width/height) is not established in
    /// this file — confirm against producers.
    pub image_sizes: Vec<(u32, u32)>,

    /// Additional named values, added through `with_extra`.
    pub model_specific: HashMap<String, ModelSpecificValue>,
}
148
149impl PreprocessedImages {
150 pub fn new(
152 pixel_values: Array4<f32>,
153 num_img_tokens: Vec<usize>,
154 image_sizes: Vec<(u32, u32)>,
155 ) -> Self {
156 Self {
157 pixel_values: pixel_values.into_dyn(),
158 num_img_tokens,
159 image_sizes,
160 model_specific: HashMap::new(),
161 }
162 }
163
164 pub fn new_dynamic(
168 pixel_values: ArrayD<f32>,
169 num_img_tokens: Vec<usize>,
170 image_sizes: Vec<(u32, u32)>,
171 ) -> Self {
172 Self {
173 pixel_values,
174 num_img_tokens,
175 image_sizes,
176 model_specific: HashMap::new(),
177 }
178 }
179
180 pub fn with_extra(mut self, key: impl Into<String>, value: ModelSpecificValue) -> Self {
182 self.model_specific.insert(key.into(), value);
183 self
184 }
185
186 pub fn batch_size(&self) -> usize {
188 self.pixel_values.shape()[0]
189 }
190
191 pub fn channels(&self) -> Result<usize, TransformError> {
199 dim_for_ndim(self.pixel_values.ndim(), 1, 2, self.pixel_values.shape())
200 }
201
202 pub fn height(&self) -> Result<usize, TransformError> {
210 dim_for_ndim(self.pixel_values.ndim(), 2, 3, self.pixel_values.shape())
211 }
212
213 pub fn width(&self) -> Result<usize, TransformError> {
221 dim_for_ndim(self.pixel_values.ndim(), 3, 4, self.pixel_values.shape())
222 }
223
224 pub fn ndim(&self) -> usize {
226 self.pixel_values.ndim()
227 }
228
229 pub fn total_tokens(&self) -> usize {
231 self.num_img_tokens.iter().sum()
232 }
233
234 pub fn pixel_values_flat(&self) -> Cow<'_, [f32]> {
236 match self.pixel_values.as_slice() {
237 Some(slice) => Cow::Borrowed(slice),
238 None => Cow::Owned(self.pixel_values.iter().copied().collect()),
239 }
240 }
241
242 pub fn pixel_values_shape(&self) -> Vec<usize> {
244 self.pixel_values.shape().to_vec()
245 }
246
247 pub fn num_images(&self) -> usize {
249 self.image_sizes.len()
250 }
251
252 pub fn batched_keys(layouts: &HashMap<String, FieldLayout>) -> Vec<String> {
254 layouts
255 .iter()
256 .filter(|(_, l)| matches!(l, FieldLayout::Batched))
257 .map(|(k, _)| k.clone())
258 .collect()
259 }
260
261 pub fn flat_keys(layouts: &HashMap<String, FieldLayout>) -> HashMap<String, String> {
265 layouts
266 .iter()
267 .filter_map(|(k, l)| match l {
268 FieldLayout::Flat { sizes_key } => Some((k.clone(), sizes_key.clone())),
269 FieldLayout::Batched => None,
270 })
271 .collect()
272 }
273}
274
/// Interface implemented by each model-specific image preprocessor.
///
/// Implementors must be `Send + Sync`; boxed trait objects of this trait are
/// stored in `ImageProcessorRegistry`.
pub trait ImagePreProcessor: Send + Sync {
    /// Returns this processor's default three-element mean.
    ///
    /// NOTE(review): presumably the per-channel normalization mean — confirm
    /// against implementors.
    fn default_mean(&self) -> [f64; 3];

    /// Returns this processor's default three-element standard deviation.
    ///
    /// NOTE(review): presumably the per-channel normalization std — confirm
    /// against implementors.
    fn default_std(&self) -> [f64; 3];

    /// Preprocesses `images` according to `config`.
    ///
    /// # Errors
    ///
    /// Returns a [`TransformError`] when preprocessing fails; the specific
    /// failure conditions are defined by each implementor.
    fn preprocess(
        &self,
        images: &[DynamicImage],
        config: &PreProcessorConfig,
    ) -> Result<PreprocessedImages, TransformError>;

    /// Computes a token count for an image of the given `width` and `height`
    /// under `config`.
    ///
    /// NOTE(review): presumably `width`/`height` are the original image
    /// dimensions in pixels — confirm against implementors.
    fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize;

    /// Returns a static name identifying this processor.
    fn model_name(&self) -> &'static str;

    /// Returns the processed size for `config`, if one is available.
    ///
    /// The default implementation returns `config.get_target_size()`.
    /// NOTE(review): the component order of the returned pair is not
    /// established in this file.
    fn get_processed_size(&self, config: &PreProcessorConfig) -> Option<(u32, u32)> {
        config.get_target_size()
    }
}
321
322pub struct ImageProcessorRegistry {
324 processors: HashMap<String, Box<dyn ImagePreProcessor>>,
325}
326
327impl ImageProcessorRegistry {
328 pub fn new() -> Self {
330 Self {
331 processors: HashMap::new(),
332 }
333 }
334
335 pub fn register(&mut self, pattern: impl Into<String>, processor: Box<dyn ImagePreProcessor>) {
337 self.processors.insert(pattern.into(), processor);
338 }
339
340 pub fn find(&self, model_id: &str, model_type: Option<&str>) -> Option<&dyn ImagePreProcessor> {
344 self.find_in_candidate(model_id)
345 .or_else(|| model_type.and_then(|mt| self.find_in_candidate(mt)))
346 }
347
348 fn find_in_candidate(&self, candidate: &str) -> Option<&dyn ImagePreProcessor> {
349 let candidate = candidate.to_lowercase();
350 for (pattern, processor) in &self.processors {
351 if candidate.contains(&pattern.to_lowercase()) {
352 return Some(processor.as_ref());
353 }
354 }
355 None
356 }
357
358 pub fn supported_patterns(&self) -> Vec<&str> {
360 self.processors.keys().map(|s| s.as_str()).collect()
361 }
362}
363
364impl Default for ImageProcessorRegistry {
365 fn default() -> Self {
366 Self::new()
367 }
368}
369
370impl ImageProcessorRegistry {
371 pub fn with_defaults() -> Self {
382 let mut registry = Self::new();
383
384 registry.register(
386 "llava-next",
387 Box::new(super::processors::LlavaNextProcessor::new()),
388 );
389 registry.register(
390 "llava_next",
391 Box::new(super::processors::LlavaNextProcessor::new()),
392 );
393 registry.register(
394 "llava-v1.6",
395 Box::new(super::processors::LlavaNextProcessor::new()),
396 );
397
398 registry.register(
402 "llava-1.5",
403 Box::new(super::processors::LlavaProcessor::new()),
404 );
405 registry.register(
406 "llava-v1.5",
407 Box::new(super::processors::LlavaProcessor::new()),
408 );
409
410 registry.register(
412 "qwen3-vl",
413 Box::new(super::processors::Qwen3VLProcessor::new()),
414 );
415 registry.register(
416 "qwen3_vl",
417 Box::new(super::processors::Qwen3VLProcessor::new()),
418 );
419
420 registry.register(
422 "qwen3.5",
423 Box::new(super::processors::Qwen3VLProcessor::new()),
424 );
425 registry.register(
426 "qwen3_5",
427 Box::new(super::processors::Qwen3VLProcessor::new()),
428 );
429 registry.register(
430 "qwen3.6",
431 Box::new(super::processors::Qwen3VLProcessor::new()),
432 );
433 registry.register(
434 "qwen3_6",
435 Box::new(super::processors::Qwen3VLProcessor::new()),
436 );
437
438 registry.register(
440 "qwen2-vl",
441 Box::new(super::processors::Qwen2VLProcessor::new()),
442 );
443 registry.register(
444 "qwen2_vl",
445 Box::new(super::processors::Qwen2VLProcessor::new()),
446 );
447
448 registry.register(
450 "qwen2.5-vl",
451 Box::new(super::processors::Qwen2VLProcessor::new()),
452 );
453 registry.register(
454 "qwen2_5-vl",
455 Box::new(super::processors::Qwen2VLProcessor::new()),
456 );
457 registry.register(
458 "qwen2_5_vl",
459 Box::new(super::processors::Qwen2VLProcessor::new()),
460 );
461
462 registry.register(
464 "phi-3-vision",
465 Box::new(super::processors::Phi3VisionProcessor::new()),
466 );
467 registry.register(
468 "phi3-vision",
469 Box::new(super::processors::Phi3VisionProcessor::new()),
470 );
471 registry.register(
472 "phi3_v",
473 Box::new(super::processors::Phi3VisionProcessor::new()),
474 );
475
476 registry.register(
478 "llama-4",
479 Box::new(super::processors::Llama4VisionProcessor::new()),
480 );
481 registry.register(
482 "llama4",
483 Box::new(super::processors::Llama4VisionProcessor::new()),
484 );
485
486 registry.register(
488 "kimi-k2",
489 Box::new(super::processors::KimiK25Processor::new()),
490 );
491 registry.register(
492 "kimi_k2",
493 Box::new(super::processors::KimiK25Processor::new()),
494 );
495
496 registry
497 }
498}
499
#[cfg(test)]
mod tests {
    use ndarray::Array4;

    use super::*;
    use crate::vision::processors::LlavaProcessor;

    /// Shape accessors and the token total on a plain 4D batch.
    #[test]
    fn test_preprocessed_images_accessors() {
        let batch = PreprocessedImages::new(
            Array4::<f32>::zeros((2, 3, 336, 336)),
            vec![576, 576],
            vec![(640, 480), (800, 600)],
        );

        assert_eq!(batch.batch_size(), 2);
        assert_eq!(batch.channels().unwrap(), 3);
        assert_eq!(batch.height().unwrap(), 336);
        assert_eq!(batch.width().unwrap(), 336);
        assert_eq!(batch.total_tokens(), 1152);
    }

    /// Chained `with_extra` calls each leave their key in `model_specific`.
    #[test]
    fn test_preprocessed_images_with_extra() {
        let grid = ModelSpecificValue::uint_1d(vec![1, 16, 16]);
        let aspect = ModelSpecificValue::Int(0);

        let batch = PreprocessedImages::new(
            Array4::<f32>::zeros((1, 3, 224, 224)),
            vec![196],
            vec![(224, 224)],
        )
        .with_extra("image_grid_thw", grid)
        .with_extra("aspect_ratio_id", aspect);

        for key in ["image_grid_thw", "aspect_ratio_id"] {
            assert!(batch.model_specific.contains_key(key));
        }
    }

    /// Each tensor constructor yields the expected variant, data and shape.
    #[test]
    fn test_model_specific_value_constructors() {
        if let ModelSpecificValue::UintTensor { data, shape } =
            ModelSpecificValue::uint_1d(vec![1, 2, 3])
        {
            assert_eq!(data, vec![1, 2, 3]);
            assert_eq!(shape, vec![3]);
        } else {
            panic!("Expected UintTensor");
        }

        if let ModelSpecificValue::UintTensor { data, shape } =
            ModelSpecificValue::uint_2d(vec![1, 2, 3, 4], 2, 2)
        {
            assert_eq!(data, vec![1, 2, 3, 4]);
            assert_eq!(shape, vec![2, 2]);
        } else {
            panic!("Expected UintTensor");
        }

        if let ModelSpecificValue::IntTensor { data, shape } =
            ModelSpecificValue::int_1d(vec![1, 2, 3])
        {
            assert_eq!(data, vec![1, 2, 3]);
            assert_eq!(shape, vec![3]);
        } else {
            panic!("Expected IntTensor");
        }

        if let ModelSpecificValue::IntTensor { data, shape } =
            ModelSpecificValue::int_2d(vec![1, 2, 3, 4], 2, 2)
        {
            assert_eq!(data, vec![1, 2, 3, 4]);
            assert_eq!(shape, vec![2, 2]);
        } else {
            panic!("Expected IntTensor");
        }
    }

    /// Flattening a freshly built 4D array yields its elements in index order.
    #[test]
    fn test_pixel_values_flat() {
        let mut grid = Array4::<f32>::zeros((1, 1, 2, 2));
        for (row, col, value) in [(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)] {
            grid[[0, 0, row, col]] = value;
        }

        let batch = PreprocessedImages::new(grid, vec![4], vec![(2, 2)]);
        let flat = batch.pixel_values_flat();

        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// The default registry resolves the LLaVA 1.5 and LLaVA-NeXT model ids.
    #[test]
    fn test_registry_with_defaults() {
        let registry = ImageProcessorRegistry::with_defaults();

        let known_ids = [
            "llava-hf/llava-1.5-7b-hf",
            "liuhaotian/llava-v1.5-7b",
            "llava-hf/llava-v1.6-mistral-7b-hf",
            "lmms-lab/llava-next-interleave-qwen-7b",
        ];
        for model_id in known_ids {
            assert!(registry.find(model_id, None).is_some());
        }

        let llava = registry.find("llava-hf/llava-1.5-7b-hf", None).unwrap();
        assert_eq!(llava.model_name(), "llava");
    }

    /// Matching is by case-insensitive substring.
    #[test]
    fn test_registry_find() {
        let mut registry = ImageProcessorRegistry::new();
        registry.register("test-model", Box::new(LlavaProcessor::new()));

        for hit in ["test-model-7b", "TEST-MODEL"] {
            assert!(registry.find(hit, None).is_some());
        }
        assert!(registry.find("other-model", None).is_none());
    }

    /// An unknown model id is resolved through `model_type` when one is given.
    #[test]
    fn test_registry_find_falls_back_to_model_type() {
        let registry = ImageProcessorRegistry::with_defaults();

        assert!(registry.find("custom-model", None).is_none());

        let via_model_type = registry
            .find("custom-model", Some("qwen3_vl"))
            .expect("qwen3 processor by model_type");
        assert_eq!(via_model_type.model_name(), "qwen3-vl");
    }

    /// A match on the model id takes precedence over a conflicting `model_type`.
    #[test]
    fn test_registry_find_preserves_fast_path() {
        let registry = ImageProcessorRegistry::with_defaults();

        let via_model_id = registry
            .find("Qwen3-VL-30B-A3B-Instruct", Some("qwen2_vl"))
            .expect("qwen3 processor by model_id");
        assert_eq!(via_model_id.model_name(), "qwen3-vl");
    }

    /// The `phi3_v` model type resolves to the Phi-3 vision processor.
    #[test]
    fn test_registry_find_phi3_model_type_fallback() {
        let registry = ImageProcessorRegistry::with_defaults();

        let via_model_type = registry
            .find("custom-model", Some("phi3_v"))
            .expect("phi3 processor by model_type");
        assert_eq!(via_model_type.model_name(), "phi3-vision");
    }
}