Skip to main content

llm_multimodal/vision/
preprocessor_config.rs

1//! HuggingFace preprocessor_config.json parsing.
2//!
3//! This module parses the `preprocessor_config.json` files from HuggingFace model
4//! repositories, providing the configuration needed for image preprocessing.
5
6use std::collections::HashMap;
7
8use image::imageops::FilterType;
9use serde::{Deserialize, Deserializer};
10
11use super::transforms;
12
13/// Struct to represent patch_size as dict {"height": x, "width": y}
14#[derive(Debug, Clone, Deserialize, Default)]
15pub struct PatchSize {
16    pub height: Option<u32>,
17    pub width: Option<u32>,
18}
19
20/// Custom deserializer for patch_size that handles both integer and dict formats.
21/// - Integer format: `"patch_size": 16` -> PatchSize { height: 16, width: 16 }
22/// - Dict format: `"patch_size": {"height": 16, "width": 16}` -> PatchSize { height: 16, width: 16 }
23fn deserialize_patch_size<'de, D>(deserializer: D) -> Result<Option<PatchSize>, D::Error>
24where
25    D: Deserializer<'de>,
26{
27    use std::fmt;
28
29    use serde::de::{self, MapAccess, Visitor};
30
31    struct PatchSizeVisitor;
32
33    impl<'de> Visitor<'de> for PatchSizeVisitor {
34        type Value = Option<PatchSize>;
35
36        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
37            formatter.write_str("an integer, a dict with height/width, or null")
38        }
39
40        fn visit_none<E>(self) -> Result<Self::Value, E>
41        where
42            E: de::Error,
43        {
44            Ok(None)
45        }
46
47        fn visit_unit<E>(self) -> Result<Self::Value, E>
48        where
49            E: de::Error,
50        {
51            Ok(None)
52        }
53
54        fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
55        where
56            E: de::Error,
57        {
58            let v = value as u32;
59            Ok(Some(PatchSize {
60                height: Some(v),
61                width: Some(v),
62            }))
63        }
64
65        fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
66        where
67            E: de::Error,
68        {
69            let v = value as u32;
70            Ok(Some(PatchSize {
71                height: Some(v),
72                width: Some(v),
73            }))
74        }
75
76        fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
77        where
78            M: MapAccess<'de>,
79        {
80            let mut height = None;
81            let mut width = None;
82
83            while let Some(key) = map.next_key::<String>()? {
84                match key.as_str() {
85                    "height" => height = Some(map.next_value::<u32>()?),
86                    "width" => width = Some(map.next_value::<u32>()?),
87                    _ => {
88                        let _ = map.next_value::<de::IgnoredAny>()?;
89                    }
90                }
91            }
92
93            Ok(Some(PatchSize { height, width }))
94        }
95    }
96
97    deserializer.deserialize_any(PatchSizeVisitor)
98}
99
/// HuggingFace preprocessor_config.json structure.
///
/// This struct captures the common fields across different vision model processors.
/// Model-specific fields are accessed via the flexible `extra` field.
///
/// Every named field is optional: a key that is absent from the JSON leaves the
/// field as `None`. Keys that match no named field (or alias) are collected
/// into `extra`.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct PreProcessorConfig {
    /// Processor class name (e.g., "CLIPImageProcessor", "Qwen2VLImageProcessor")
    #[serde(default)]
    pub image_processor_type: Option<String>,

    /// Whether to convert to RGB
    #[serde(default)]
    pub do_convert_rgb: Option<bool>,

    /// Whether to normalize with mean/std
    #[serde(default)]
    pub do_normalize: Option<bool>,

    /// Whether to pad images
    #[serde(default)]
    pub do_pad: Option<bool>,

    /// Whether to rescale pixel values (typically by 1/255)
    #[serde(default)]
    pub do_rescale: Option<bool>,

    /// Whether to resize images
    #[serde(default)]
    pub do_resize: Option<bool>,

    /// Whether to center crop after resizing
    #[serde(default)]
    pub do_center_crop: Option<bool>,

    /// Per-channel normalization mean
    ///
    /// Also accepted under the JSON key `norm_mean`.
    #[serde(default, alias = "norm_mean")]
    pub image_mean: Option<Vec<f64>>,

    /// Per-channel normalization std
    ///
    /// Also accepted under the JSON key `norm_std`.
    #[serde(default, alias = "norm_std")]
    pub image_std: Option<Vec<f64>>,

    /// Rescale factor (typically 1/255 = 0.00392156862745098)
    #[serde(default)]
    pub rescale_factor: Option<f64>,

    /// PIL resampling filter enum (0=Nearest, 1=Lanczos, 2=Bilinear, 3=Bicubic)
    ///
    /// Also accepted under the JSON key `resample`.
    #[serde(default, alias = "resample")]
    pub resampling: Option<usize>,

    /// Target size for resizing
    /// Can be {"height": H, "width": W} or {"shortest_edge": S}
    ///
    /// Stored as a string-keyed map; every value must deserialize as `u32`.
    #[serde(default)]
    pub size: Option<HashMap<String, u32>>,

    /// Target size for center cropping
    ///
    /// Stored as a string-keyed map; every value must deserialize as `u32`.
    #[serde(default)]
    pub crop_size: Option<HashMap<String, u32>>,

    // =====================
    // Model-specific fields
    // =====================
    /// Vision encoder patch size (typically 14 or 16)
    /// Can be an integer or a dict {"height": x, "width": y}
    ///
    /// Parsed by `deserialize_patch_size`.
    #[serde(default, deserialize_with = "deserialize_patch_size")]
    pub patch_size: Option<PatchSize>,

    /// Qwen-VL: merge size for token reduction
    #[serde(default)]
    pub merge_size: Option<usize>,

    /// Qwen-VL: minimum total pixels
    #[serde(default)]
    pub min_pixels: Option<usize>,

    /// Qwen-VL: maximum total pixels
    #[serde(default)]
    pub max_pixels: Option<usize>,

    /// Qwen-VL: temporal patch size for video
    #[serde(default)]
    pub temporal_patch_size: Option<usize>,

    /// Phi3-Vision: number of image crops
    #[serde(default)]
    pub num_crops: Option<usize>,

    /// Phi4-Vision: dynamic HD max crops
    #[serde(default)]
    pub dynamic_hd: Option<usize>,

    /// LLaMA-Vision: maximum image tiles
    #[serde(default)]
    pub max_image_tiles: Option<usize>,

    /// Fixed number of image tokens (some models use this)
    #[serde(default)]
    pub num_img_tokens: Option<usize>,

    // =====================
    // Special tokens
    // =====================
    /// Image start token
    #[serde(default)]
    pub im_start_token: Option<String>,

    /// Image end token
    #[serde(default)]
    pub im_end_token: Option<String>,

    /// Slice start token (for multi-crop)
    #[serde(default)]
    pub slice_start_token: Option<String>,

    /// Slice end token
    #[serde(default)]
    pub slice_end_token: Option<String>,

    /// Vision start token (alternative naming)
    #[serde(default)]
    pub vision_start_token: Option<String>,

    /// Vision end token
    #[serde(default)]
    pub vision_end_token: Option<String>,

    /// Catch-all for model-specific fields not explicitly defined
    ///
    /// Holds the raw JSON value of every top-level key that no field above
    /// claimed; read individual entries with `get_extra`.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
230
231impl PreProcessorConfig {
232    /// Parse from JSON string.
233    ///
234    /// Handles both standard HuggingFace format (top-level fields) and Kimi-K2.5's
235    /// nested format where values are under `media_proc_cfg`.
236    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
237        let raw: serde_json::Value = serde_json::from_str(json)?;
238        Self::from_value(raw)
239    }
240
241    /// Parse from JSON value.
242    ///
243    /// Handles both standard HuggingFace format (top-level fields) and Kimi-K2.5's
244    /// nested format where values are under `media_proc_cfg`.
245    pub fn from_value(value: serde_json::Value) -> Result<Self, serde_json::Error> {
246        let mut config: Self = serde_json::from_value(value.clone())?;
247        Self::apply_nested_media_cfg(&mut config, &value);
248        Ok(config)
249    }
250
251    /// Extract values from nested `media_proc_cfg` (used by Kimi-K2.5 and
252    /// similar models) when top-level fields are missing.
253    fn apply_nested_media_cfg(config: &mut Self, raw: &serde_json::Value) {
254        let Some(media_cfg) = raw.get("media_proc_cfg") else {
255            return;
256        };
257        if config.image_mean.is_none() {
258            config.image_mean = media_cfg
259                .get("image_mean")
260                .and_then(|v| serde_json::from_value(v.clone()).ok());
261        }
262        if config.image_std.is_none() {
263            config.image_std = media_cfg
264                .get("image_std")
265                .and_then(|v| serde_json::from_value(v.clone()).ok());
266        }
267        if config.patch_size.is_none() {
268            config.patch_size = media_cfg.get("patch_size").and_then(|v| {
269                v.as_u64().map(|ps| PatchSize {
270                    height: Some(ps as u32),
271                    width: Some(ps as u32),
272                })
273            });
274        }
275        if config.merge_size.is_none() {
276            config.merge_size = media_cfg
277                .get("merge_kernel_size")
278                .and_then(|v| v.as_u64())
279                .map(|v| v as usize);
280        }
281        // Also extract Kimi-specific limits into the extra map
282        // so processors can read them via get_extra()
283        for key in ["in_patch_limit", "patch_limit_on_one_side"] {
284            if !config.extra.contains_key(key) {
285                if let Some(v) = media_cfg.get(key) {
286                    config.extra.insert(key.to_string(), v.clone());
287                }
288            }
289        }
290    }
291
292    /// Get patch size as a simple usize.
293    ///
294    /// Returns the height value from PatchSize if available, falling back to provided default.
295    pub fn get_patch_size(&self, default: usize) -> usize {
296        self.patch_size
297            .as_ref()
298            .and_then(|p| p.height)
299            .map(|h| h as usize)
300            .unwrap_or(default)
301    }
302
303    /// Get image mean as fixed array, with fallback to CLIP defaults.
304    pub fn get_image_mean(&self) -> [f64; 3] {
305        self.image_mean
306            .as_ref()
307            .and_then(|v| {
308                if v.len() >= 3 {
309                    Some([v[0], v[1], v[2]])
310                } else {
311                    None
312                }
313            })
314            .unwrap_or(Self::CLIP_MEAN)
315    }
316
317    /// Get image std as fixed array, with fallback to CLIP defaults.
318    pub fn get_image_std(&self) -> [f64; 3] {
319        self.image_std
320            .as_ref()
321            .and_then(|v| {
322                if v.len() >= 3 {
323                    Some([v[0], v[1], v[2]])
324                } else {
325                    None
326                }
327            })
328            .unwrap_or(Self::CLIP_STD)
329    }
330
331    /// Get target size from various config formats.
332    ///
333    /// Handles both `{"height": H, "width": W}` and `{"shortest_edge": S}` formats.
334    /// Returns (height, width).
335    pub fn get_target_size(&self) -> Option<(u32, u32)> {
336        self.size.as_ref().map(|s| {
337            // Try explicit height/width first
338            let h = s
339                .get("height")
340                .or_else(|| s.get("shortest_edge"))
341                .copied()
342                .unwrap_or(224);
343            let w = s
344                .get("width")
345                .or_else(|| s.get("shortest_edge"))
346                .copied()
347                .unwrap_or(224);
348            (h, w)
349        })
350    }
351
352    /// Get crop size.
353    ///
354    /// Returns (height, width).
355    pub fn get_crop_size(&self) -> Option<(u32, u32)> {
356        self.crop_size.as_ref().map(|s| {
357            let h = s.get("height").copied().unwrap_or(224);
358            let w = s.get("width").copied().unwrap_or(224);
359            (h, w)
360        })
361    }
362
363    /// Get the interpolation filter for resizing.
364    pub fn get_filter(&self) -> FilterType {
365        transforms::pil_to_filter(self.resampling)
366    }
367
368    /// Check if normalization should be applied.
369    pub fn should_normalize(&self) -> bool {
370        self.do_normalize.unwrap_or(true)
371    }
372
373    /// Check if rescaling should be applied.
374    pub fn should_rescale(&self) -> bool {
375        self.do_rescale.unwrap_or(false)
376    }
377
378    /// Check if resizing should be applied.
379    pub fn should_resize(&self) -> bool {
380        self.do_resize.unwrap_or(true)
381    }
382
383    /// Check if center cropping should be applied.
384    pub fn should_center_crop(&self) -> bool {
385        self.do_center_crop.unwrap_or(false)
386    }
387
388    /// Get rescale factor with default.
389    pub fn get_rescale_factor(&self) -> f64 {
390        self.rescale_factor.unwrap_or(1.0 / 255.0)
391    }
392
393    /// Get a typed extra field.
394    pub fn get_extra<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
395        self.extra
396            .get(key)
397            .and_then(|v| serde_json::from_value(v.clone()).ok())
398    }
399
400    // Common default values
401    pub const CLIP_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
402    pub const CLIP_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
403
404    pub const IMAGENET_MEAN: [f64; 3] = [0.485, 0.456, 0.406];
405    pub const IMAGENET_STD: [f64; 3] = [0.229, 0.224, 0.225];
406
407    pub const SIGLIP_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
408    pub const SIGLIP_STD: [f64; 3] = [0.5, 0.5, 0.5];
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a fixture, panicking (and thereby failing the test) on invalid input.
    fn parse(json: &str) -> PreProcessorConfig {
        PreProcessorConfig::from_json(json).unwrap()
    }

    /// Asserts that two floats differ by less than `tolerance`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// A CLIP-style config: flags, the `resample` alias and a shortest-edge size.
    #[test]
    fn test_parse_clip_config() {
        let json = r#"{
            "do_center_crop": true,
            "do_normalize": true,
            "do_resize": true,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "resample": 3,
            "size": {"shortest_edge": 224}
        }"#;
        let cfg = parse(json);

        assert!(cfg.should_normalize());
        assert!(cfg.should_center_crop());
        assert!(cfg.should_resize());
        assert_eq!(cfg.resampling, Some(3));

        let (height, width) = cfg.get_target_size().unwrap();
        assert_eq!(height, 224);
        assert_eq!(width, 224);

        assert_close(cfg.get_image_mean()[0], 0.48145466, 1e-6);
    }

    /// A Qwen-VL-style config: pixel limits, integer patch size and merge size.
    #[test]
    fn test_parse_qwen_vl_config() {
        let json = r#"{
            "do_normalize": true,
            "do_rescale": true,
            "do_resize": true,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "min_pixels": 200704,
            "max_pixels": 1003520,
            "patch_size": 14,
            "merge_size": 2,
            "temporal_patch_size": 2,
            "rescale_factor": 0.00392156862745098
        }"#;
        let cfg = parse(json);

        assert_eq!(cfg.min_pixels, Some(200704));
        assert_eq!(cfg.max_pixels, Some(1003520));
        assert_eq!(cfg.get_patch_size(0), 14);
        assert_eq!(cfg.merge_size, Some(2));
        assert_close(cfg.get_rescale_factor(), 1.0 / 255.0, 1e-10);
    }

    /// Both supported shapes of `size` resolve to a (height, width) pair.
    #[test]
    fn test_parse_size_formats() {
        // Height/width format
        let explicit = parse(r#"{"size": {"height": 336, "width": 336}}"#);
        assert_eq!(explicit.get_target_size(), Some((336, 336)));

        // Shortest edge format
        let shortest = parse(r#"{"size": {"shortest_edge": 224}}"#);
        assert_eq!(shortest.get_target_size(), Some((224, 224)));
    }

    /// An empty config falls back to the documented defaults.
    #[test]
    fn test_defaults() {
        let cfg = PreProcessorConfig::default();

        // Should use CLIP defaults when not specified
        assert_close(cfg.get_image_mean()[0], PreProcessorConfig::CLIP_MEAN[0], 1e-6);

        // Default behaviors
        assert!(cfg.should_normalize()); // true by default
        assert!(!cfg.should_rescale()); // false by default
        assert!(cfg.should_resize()); // true by default
        assert!(!cfg.should_center_crop()); // false by default
    }

    /// Resampling value 3 maps to the `CatmullRom` filter.
    #[test]
    fn test_filter_conversion() {
        let cfg = parse(r#"{"resampling": 3}"#);
        assert!(matches!(cfg.get_filter(), FilterType::CatmullRom));
    }

    /// Unknown keys are kept and can be read back with a requested type.
    #[test]
    fn test_extra_fields() {
        let json = r#"{
            "custom_field": 42,
            "nested": {"foo": "bar"}
        }"#;
        let cfg = parse(json);

        let custom: Option<i32> = cfg.get_extra("custom_field");
        assert_eq!(custom, Some(42));

        let nested: Option<HashMap<String, String>> = cfg.get_extra("nested");
        let nested = nested.unwrap();
        assert_eq!(nested.get("foo"), Some(&"bar".to_string()));
    }

    /// Values nested under `media_proc_cfg` fill fields missing at the top level.
    #[test]
    fn test_parse_kimi_nested_media_proc_cfg() {
        let json = r#"{
            "auto_map": {
                "AutoProcessor": "kimi_k25_processor.KimiK25Processor"
            },
            "media_proc_cfg": {
                "in_patch_limit": 16384,
                "patch_size": 14,
                "image_mean": [0.5, 0.5, 0.5],
                "image_std": [0.5, 0.5, 0.5],
                "merge_kernel_size": 2,
                "patch_limit_on_one_side": 512
            }
        }"#;
        let cfg = parse(json);

        // image_mean/std should be extracted from media_proc_cfg
        let mean = cfg.get_image_mean();
        assert_close(mean[0], 0.5, 1e-6);
        assert_close(mean[1], 0.5, 1e-6);
        assert_close(mean[2], 0.5, 1e-6);

        assert_close(cfg.get_image_std()[0], 0.5, 1e-6);

        assert_eq!(cfg.get_patch_size(0), 14);
        assert_eq!(cfg.merge_size, Some(2));
    }
}
551}