1use std::{borrow::Cow, collections::HashMap};
7
8use image::DynamicImage;
9use ndarray::{Array4, ArrayD};
10
11use super::{preprocessor_config::PreProcessorConfig, transforms::TransformError};
12use crate::types::FieldLayout;
13
14fn dim_for_ndim(
17 ndim: usize,
18 axis_4d: usize,
19 axis_5d: usize,
20 shape: &[usize],
21) -> Result<usize, TransformError> {
22 match ndim {
23 4 => Ok(shape[axis_4d]),
24 5 => Ok(shape[axis_5d]),
25 _ => Err(TransformError::InvalidShape {
26 expected: format!("4D or 5D pixel_values tensor, got {ndim}D"),
27 actual: shape.to_vec(),
28 }),
29 }
30}
31
/// A model-specific auxiliary value attached to [`PreprocessedImages`].
///
/// Tensor variants carry flat `data` plus a `shape`; the remaining variants
/// carry scalars or plain vectors.
#[derive(Debug, Clone)]
pub enum ModelSpecificValue {
    /// `f32` tensor: flat `data` with logical `shape` (presumably row-major —
    /// confirm against consumers).
    Tensor {
        data: Vec<f32>,
        shape: Vec<usize>,
    },

    /// `i64` tensor: flat `data` with logical `shape`.
    IntTensor {
        data: Vec<i64>,
        shape: Vec<usize>,
    },

    /// `u32` tensor: flat `data` with logical `shape`.
    UintTensor {
        data: Vec<u32>,
        shape: Vec<usize>,
    },

    /// A single signed integer.
    Int(i64),

    /// A single double-precision float.
    Float(f64),

    /// A plain vector of signed integers (no shape attached).
    IntVec(Vec<i64>),

    /// A plain vector of unsigned integers (no shape attached).
    UintVec(Vec<u32>),

    /// A plain vector of `f32` values (no shape attached).
    FloatVec(Vec<f32>),

    /// A vector of `(u32, u32)` pairs.
    TupleVec(Vec<(u32, u32)>),

    /// A boolean flag.
    Bool(bool),
}
68
69impl ModelSpecificValue {
70 pub fn uint_1d(data: Vec<u32>) -> Self {
72 let len = data.len();
73 Self::UintTensor {
74 data,
75 shape: vec![len],
76 }
77 }
78
79 pub fn uint_2d(data: Vec<u32>, rows: usize, cols: usize) -> Self {
81 Self::UintTensor {
82 data,
83 shape: vec![rows, cols],
84 }
85 }
86
87 pub fn int_1d(data: Vec<i64>) -> Self {
89 let len = data.len();
90 Self::IntTensor {
91 data,
92 shape: vec![len],
93 }
94 }
95
96 pub fn int_2d(data: Vec<i64>, rows: usize, cols: usize) -> Self {
98 Self::IntTensor {
99 data,
100 shape: vec![rows, cols],
101 }
102 }
103
104 pub fn first_dim(&self) -> Option<usize> {
106 match self {
107 Self::Tensor { shape, .. }
108 | Self::IntTensor { shape, .. }
109 | Self::UintTensor { shape, .. } => shape.first().copied(),
110 _ => None,
111 }
112 }
113}
114
115#[derive(Debug, Clone)]
120pub struct PreprocessedImages {
121 pub pixel_values: ArrayD<f32>,
128
129 pub num_img_tokens: Vec<usize>,
134
135 pub image_sizes: Vec<(u32, u32)>,
139
140 pub model_specific: HashMap<String, ModelSpecificValue>,
147}
148
149impl PreprocessedImages {
150 pub fn new(
152 pixel_values: Array4<f32>,
153 num_img_tokens: Vec<usize>,
154 image_sizes: Vec<(u32, u32)>,
155 ) -> Self {
156 Self {
157 pixel_values: pixel_values.into_dyn(),
158 num_img_tokens,
159 image_sizes,
160 model_specific: HashMap::new(),
161 }
162 }
163
164 pub fn new_dynamic(
168 pixel_values: ArrayD<f32>,
169 num_img_tokens: Vec<usize>,
170 image_sizes: Vec<(u32, u32)>,
171 ) -> Self {
172 Self {
173 pixel_values,
174 num_img_tokens,
175 image_sizes,
176 model_specific: HashMap::new(),
177 }
178 }
179
180 pub fn with_extra(mut self, key: impl Into<String>, value: ModelSpecificValue) -> Self {
182 self.model_specific.insert(key.into(), value);
183 self
184 }
185
186 pub fn batch_size(&self) -> usize {
188 self.pixel_values.shape()[0]
189 }
190
191 pub fn channels(&self) -> Result<usize, TransformError> {
199 dim_for_ndim(self.pixel_values.ndim(), 1, 2, self.pixel_values.shape())
200 }
201
202 pub fn height(&self) -> Result<usize, TransformError> {
210 dim_for_ndim(self.pixel_values.ndim(), 2, 3, self.pixel_values.shape())
211 }
212
213 pub fn width(&self) -> Result<usize, TransformError> {
221 dim_for_ndim(self.pixel_values.ndim(), 3, 4, self.pixel_values.shape())
222 }
223
224 pub fn ndim(&self) -> usize {
226 self.pixel_values.ndim()
227 }
228
229 pub fn total_tokens(&self) -> usize {
231 self.num_img_tokens.iter().sum()
232 }
233
234 pub fn pixel_values_flat(&self) -> Cow<'_, [f32]> {
236 match self.pixel_values.as_slice() {
237 Some(slice) => Cow::Borrowed(slice),
238 None => Cow::Owned(self.pixel_values.iter().copied().collect()),
239 }
240 }
241
242 pub fn pixel_values_shape(&self) -> Vec<usize> {
244 self.pixel_values.shape().to_vec()
245 }
246
247 pub fn num_images(&self) -> usize {
249 self.image_sizes.len()
250 }
251
252 pub fn batched_keys(layouts: &HashMap<String, FieldLayout>) -> Vec<String> {
254 layouts
255 .iter()
256 .filter(|(_, l)| matches!(l, FieldLayout::Batched))
257 .map(|(k, _)| k.clone())
258 .collect()
259 }
260
261 pub fn flat_keys(layouts: &HashMap<String, FieldLayout>) -> HashMap<String, String> {
265 layouts
266 .iter()
267 .filter_map(|(k, l)| match l {
268 FieldLayout::Flat { sizes_key } => Some((k.clone(), sizes_key.clone())),
269 FieldLayout::Batched => None,
270 })
271 .collect()
272 }
273}
274
275pub trait ImagePreProcessor: Send + Sync {
280 fn default_mean(&self) -> [f64; 3];
282
283 fn default_std(&self) -> [f64; 3];
285
286 fn preprocess(
295 &self,
296 images: &[DynamicImage],
297 config: &PreProcessorConfig,
298 ) -> Result<PreprocessedImages, TransformError>;
299
300 fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize;
310
311 fn model_name(&self) -> &'static str;
313
314 fn get_processed_size(&self, config: &PreProcessorConfig) -> Option<(u32, u32)> {
318 config.get_target_size()
319 }
320}
321
322pub struct ImageProcessorRegistry {
324 processors: HashMap<String, Box<dyn ImagePreProcessor>>,
325}
326
327impl ImageProcessorRegistry {
328 pub fn new() -> Self {
330 Self {
331 processors: HashMap::new(),
332 }
333 }
334
335 pub fn register(&mut self, pattern: impl Into<String>, processor: Box<dyn ImagePreProcessor>) {
337 self.processors.insert(pattern.into(), processor);
338 }
339
340 pub fn find(&self, model_id: &str, model_type: Option<&str>) -> Option<&dyn ImagePreProcessor> {
344 self.find_in_candidate(model_id)
345 .or_else(|| model_type.and_then(|mt| self.find_in_candidate(mt)))
346 }
347
348 fn find_in_candidate(&self, candidate: &str) -> Option<&dyn ImagePreProcessor> {
349 let candidate = candidate.to_lowercase();
350 for (pattern, processor) in &self.processors {
351 if candidate.contains(&pattern.to_lowercase()) {
352 return Some(processor.as_ref());
353 }
354 }
355 None
356 }
357
358 pub fn supported_patterns(&self) -> Vec<&str> {
360 self.processors.keys().map(|s| s.as_str()).collect()
361 }
362}
363
364impl Default for ImageProcessorRegistry {
365 fn default() -> Self {
366 Self::new()
367 }
368}
369
370impl ImageProcessorRegistry {
371 pub fn with_defaults() -> Self {
381 let mut registry = Self::new();
382
383 registry.register(
385 "llava-next",
386 Box::new(super::processors::LlavaNextProcessor::new()),
387 );
388 registry.register(
389 "llava_next",
390 Box::new(super::processors::LlavaNextProcessor::new()),
391 );
392 registry.register(
393 "llava-v1.6",
394 Box::new(super::processors::LlavaNextProcessor::new()),
395 );
396
397 registry.register(
401 "llava-1.5",
402 Box::new(super::processors::LlavaProcessor::new()),
403 );
404 registry.register(
405 "llava-v1.5",
406 Box::new(super::processors::LlavaProcessor::new()),
407 );
408
409 registry.register(
411 "qwen3-vl",
412 Box::new(super::processors::Qwen3VLProcessor::new()),
413 );
414 registry.register(
415 "qwen3_vl",
416 Box::new(super::processors::Qwen3VLProcessor::new()),
417 );
418
419 registry.register(
421 "qwen2-vl",
422 Box::new(super::processors::Qwen2VLProcessor::new()),
423 );
424 registry.register(
425 "qwen2_vl",
426 Box::new(super::processors::Qwen2VLProcessor::new()),
427 );
428
429 registry.register(
431 "qwen2.5-vl",
432 Box::new(super::processors::Qwen2VLProcessor::new()),
433 );
434 registry.register(
435 "qwen2_5-vl",
436 Box::new(super::processors::Qwen2VLProcessor::new()),
437 );
438 registry.register(
439 "qwen2_5_vl",
440 Box::new(super::processors::Qwen2VLProcessor::new()),
441 );
442
443 registry.register(
445 "phi-3-vision",
446 Box::new(super::processors::Phi3VisionProcessor::new()),
447 );
448 registry.register(
449 "phi3-vision",
450 Box::new(super::processors::Phi3VisionProcessor::new()),
451 );
452 registry.register(
453 "phi3_v",
454 Box::new(super::processors::Phi3VisionProcessor::new()),
455 );
456
457 registry.register(
459 "llama-4",
460 Box::new(super::processors::Llama4VisionProcessor::new()),
461 );
462 registry.register(
463 "llama4",
464 Box::new(super::processors::Llama4VisionProcessor::new()),
465 );
466
467 registry
468 }
469}
470
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use ndarray::{Array4, ArrayD, IxDyn};

    use super::*;
    use crate::vision::processors::LlavaProcessor;

    #[test]
    fn test_preprocessed_images_accessors() {
        let pixel_values = Array4::<f32>::zeros((2, 3, 336, 336));
        let images =
            PreprocessedImages::new(pixel_values, vec![576, 576], vec![(640, 480), (800, 600)]);

        assert_eq!(images.batch_size(), 2);
        assert_eq!(images.channels().unwrap(), 3);
        assert_eq!(images.height().unwrap(), 336);
        assert_eq!(images.width().unwrap(), 336);
        assert_eq!(images.total_tokens(), 1152);
    }

    /// The 5-D path of `dim_for_ndim` (e.g. tiled/patched inputs) uses axes 2..=4.
    #[test]
    fn test_preprocessed_images_5d_accessors() {
        let pixel_values = ArrayD::<f32>::zeros(IxDyn(&[2, 4, 3, 10, 12]));
        let images =
            PreprocessedImages::new_dynamic(pixel_values, vec![8, 8], vec![(12, 10), (12, 10)]);

        assert_eq!(images.ndim(), 5);
        assert_eq!(images.batch_size(), 2);
        assert_eq!(images.channels().unwrap(), 3);
        assert_eq!(images.height().unwrap(), 10);
        assert_eq!(images.width().unwrap(), 12);
        assert_eq!(images.pixel_values_shape(), vec![2, 4, 3, 10, 12]);
        assert_eq!(images.num_images(), 2);
    }

    #[test]
    fn test_preprocessed_images_rejects_unsupported_rank() {
        let pixel_values = ArrayD::<f32>::zeros(IxDyn(&[3, 10, 12]));
        let images = PreprocessedImages::new_dynamic(pixel_values, vec![1], vec![(12, 10)]);

        for result in [images.channels(), images.height(), images.width()] {
            match result {
                Err(TransformError::InvalidShape { actual, .. }) => {
                    assert_eq!(actual, vec![3, 10, 12]);
                }
                other => panic!("expected InvalidShape, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_preprocessed_images_with_extra() {
        let pixel_values = Array4::<f32>::zeros((1, 3, 224, 224));
        let images = PreprocessedImages::new(pixel_values, vec![196], vec![(224, 224)])
            .with_extra(
                "image_grid_thw",
                ModelSpecificValue::uint_1d(vec![1, 16, 16]),
            )
            .with_extra("aspect_ratio_id", ModelSpecificValue::Int(0));

        assert!(images.model_specific.contains_key("image_grid_thw"));
        assert!(images.model_specific.contains_key("aspect_ratio_id"));
    }

    #[test]
    fn test_model_specific_value_constructors() {
        let uint_1d = ModelSpecificValue::uint_1d(vec![1, 2, 3]);
        match uint_1d {
            ModelSpecificValue::UintTensor { data, shape } => {
                assert_eq!(data, vec![1, 2, 3]);
                assert_eq!(shape, vec![3]);
            }
            _ => panic!("Expected UintTensor"),
        }

        let uint_2d = ModelSpecificValue::uint_2d(vec![1, 2, 3, 4], 2, 2);
        match uint_2d {
            ModelSpecificValue::UintTensor { data, shape } => {
                assert_eq!(data, vec![1, 2, 3, 4]);
                assert_eq!(shape, vec![2, 2]);
            }
            _ => panic!("Expected UintTensor"),
        }

        let int_1d = ModelSpecificValue::int_1d(vec![1, 2, 3]);
        match int_1d {
            ModelSpecificValue::IntTensor { data, shape } => {
                assert_eq!(data, vec![1, 2, 3]);
                assert_eq!(shape, vec![3]);
            }
            _ => panic!("Expected IntTensor"),
        }

        let int_2d = ModelSpecificValue::int_2d(vec![1, 2, 3, 4], 2, 2);
        match int_2d {
            ModelSpecificValue::IntTensor { data, shape } => {
                assert_eq!(data, vec![1, 2, 3, 4]);
                assert_eq!(shape, vec![2, 2]);
            }
            _ => panic!("Expected IntTensor"),
        }
    }

    #[test]
    fn test_model_specific_value_first_dim() {
        assert_eq!(
            ModelSpecificValue::uint_2d(vec![1, 16, 16, 1, 8, 8], 2, 3).first_dim(),
            Some(2)
        );
        assert_eq!(ModelSpecificValue::int_1d(vec![7, 8, 9]).first_dim(), Some(3));
        assert_eq!(
            ModelSpecificValue::Tensor {
                data: vec![],
                shape: vec![]
            }
            .first_dim(),
            None
        );
        assert_eq!(ModelSpecificValue::Int(1).first_dim(), None);
        assert_eq!(ModelSpecificValue::Bool(false).first_dim(), None);
    }

    #[test]
    fn test_pixel_values_flat() {
        let mut pixel_values = Array4::<f32>::zeros((1, 1, 2, 2));
        pixel_values[[0, 0, 0, 0]] = 1.0;
        pixel_values[[0, 0, 0, 1]] = 2.0;
        pixel_values[[0, 0, 1, 0]] = 3.0;
        pixel_values[[0, 0, 1, 1]] = 4.0;

        let images = PreprocessedImages::new(pixel_values, vec![4], vec![(2, 2)]);
        let flat = images.pixel_values_flat();

        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// A non-standard-layout view must be copied out in logical order.
    #[test]
    fn test_pixel_values_flat_non_contiguous() {
        let pixel_values =
            Array4::<f32>::from_shape_vec((1, 1, 2, 2), vec![1.0, 2.0, 3.0, 4.0])
                .unwrap()
                .reversed_axes();

        let images = PreprocessedImages::new(pixel_values, vec![4], vec![(2, 2)]);
        let flat = images.pixel_values_flat();

        assert!(matches!(flat, Cow::Owned(_)));
        assert_eq!(flat, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = ImageProcessorRegistry::with_defaults();

        assert!(registry.find("llava-hf/llava-1.5-7b-hf", None).is_some());
        assert!(registry.find("liuhaotian/llava-v1.5-7b", None).is_some());

        assert!(registry
            .find("llava-hf/llava-v1.6-mistral-7b-hf", None)
            .is_some());
        assert!(registry
            .find("lmms-lab/llava-next-interleave-qwen-7b", None)
            .is_some());

        let processor = registry.find("llava-hf/llava-1.5-7b-hf", None).unwrap();
        assert_eq!(processor.model_name(), "llava");
    }

    #[test]
    fn test_registry_find() {
        let mut registry = ImageProcessorRegistry::new();

        registry.register("test-model", Box::new(LlavaProcessor::new()));

        assert!(registry.find("test-model-7b", None).is_some());
        assert!(registry.find("TEST-MODEL", None).is_some());
        assert!(registry.find("other-model", None).is_none());
    }

    #[test]
    fn test_registry_find_falls_back_to_model_type() {
        let registry = ImageProcessorRegistry::with_defaults();

        assert!(registry.find("custom-model", None).is_none());

        let processor = registry
            .find("custom-model", Some("qwen3_vl"))
            .expect("qwen3 processor by model_type");
        assert_eq!(processor.model_name(), "qwen3-vl");
    }

    #[test]
    fn test_registry_find_preserves_fast_path() {
        let registry = ImageProcessorRegistry::with_defaults();

        let processor = registry
            .find("Qwen3-VL-30B-A3B-Instruct", Some("qwen2_vl"))
            .expect("qwen3 processor by model_id");
        assert_eq!(processor.model_name(), "qwen3-vl");
    }

    #[test]
    fn test_registry_find_phi3_model_type_fallback() {
        let registry = ImageProcessorRegistry::with_defaults();

        let processor = registry
            .find("custom-model", Some("phi3_v"))
            .expect("phi3 processor by model_type");
        assert_eq!(processor.model_name(), "phi3-vision");
    }
}