1use std::collections::HashMap;
7
8use image::imageops::FilterType;
9use serde::{Deserialize, Deserializer};
10
11use super::transforms;
12
13#[derive(Debug, Clone, Deserialize, Default)]
15pub struct PatchSize {
16 pub height: Option<u32>,
17 pub width: Option<u32>,
18}
19
20fn deserialize_patch_size<'de, D>(deserializer: D) -> Result<Option<PatchSize>, D::Error>
24where
25 D: Deserializer<'de>,
26{
27 use std::fmt;
28
29 use serde::de::{self, MapAccess, Visitor};
30
31 struct PatchSizeVisitor;
32
33 impl<'de> Visitor<'de> for PatchSizeVisitor {
34 type Value = Option<PatchSize>;
35
36 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
37 formatter.write_str("an integer, a dict with height/width, or null")
38 }
39
40 fn visit_none<E>(self) -> Result<Self::Value, E>
41 where
42 E: de::Error,
43 {
44 Ok(None)
45 }
46
47 fn visit_unit<E>(self) -> Result<Self::Value, E>
48 where
49 E: de::Error,
50 {
51 Ok(None)
52 }
53
54 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
55 where
56 E: de::Error,
57 {
58 let v = value as u32;
59 Ok(Some(PatchSize {
60 height: Some(v),
61 width: Some(v),
62 }))
63 }
64
65 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
66 where
67 E: de::Error,
68 {
69 let v = value as u32;
70 Ok(Some(PatchSize {
71 height: Some(v),
72 width: Some(v),
73 }))
74 }
75
76 fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
77 where
78 M: MapAccess<'de>,
79 {
80 let mut height = None;
81 let mut width = None;
82
83 while let Some(key) = map.next_key::<String>()? {
84 match key.as_str() {
85 "height" => height = Some(map.next_value::<u32>()?),
86 "width" => width = Some(map.next_value::<u32>()?),
87 _ => {
88 let _ = map.next_value::<de::IgnoredAny>()?;
89 }
90 }
91 }
92
93 Ok(Some(PatchSize { height, width }))
94 }
95 }
96
97 deserializer.deserialize_any(PatchSizeVisitor)
98}
99
100#[derive(Debug, Clone, Deserialize, Default)]
105pub struct PreProcessorConfig {
106 #[serde(default)]
108 pub image_processor_type: Option<String>,
109
110 #[serde(default)]
112 pub do_convert_rgb: Option<bool>,
113
114 #[serde(default)]
116 pub do_normalize: Option<bool>,
117
118 #[serde(default)]
120 pub do_pad: Option<bool>,
121
122 #[serde(default)]
124 pub do_rescale: Option<bool>,
125
126 #[serde(default)]
128 pub do_resize: Option<bool>,
129
130 #[serde(default)]
132 pub do_center_crop: Option<bool>,
133
134 #[serde(default, alias = "norm_mean")]
136 pub image_mean: Option<Vec<f64>>,
137
138 #[serde(default, alias = "norm_std")]
140 pub image_std: Option<Vec<f64>>,
141
142 #[serde(default)]
144 pub rescale_factor: Option<f64>,
145
146 #[serde(default, alias = "resample")]
148 pub resampling: Option<usize>,
149
150 #[serde(default)]
153 pub size: Option<HashMap<String, u32>>,
154
155 #[serde(default)]
157 pub crop_size: Option<HashMap<String, u32>>,
158
159 #[serde(default, deserialize_with = "deserialize_patch_size")]
165 pub patch_size: Option<PatchSize>,
166
167 #[serde(default)]
169 pub merge_size: Option<usize>,
170
171 #[serde(default)]
173 pub min_pixels: Option<usize>,
174
175 #[serde(default)]
177 pub max_pixels: Option<usize>,
178
179 #[serde(default)]
181 pub temporal_patch_size: Option<usize>,
182
183 #[serde(default)]
185 pub num_crops: Option<usize>,
186
187 #[serde(default)]
189 pub dynamic_hd: Option<usize>,
190
191 #[serde(default)]
193 pub max_image_tiles: Option<usize>,
194
195 #[serde(default)]
197 pub num_img_tokens: Option<usize>,
198
199 #[serde(default)]
204 pub im_start_token: Option<String>,
205
206 #[serde(default)]
208 pub im_end_token: Option<String>,
209
210 #[serde(default)]
212 pub slice_start_token: Option<String>,
213
214 #[serde(default)]
216 pub slice_end_token: Option<String>,
217
218 #[serde(default)]
220 pub vision_start_token: Option<String>,
221
222 #[serde(default)]
224 pub vision_end_token: Option<String>,
225
226 #[serde(flatten)]
228 pub extra: HashMap<String, serde_json::Value>,
229}
230
231impl PreProcessorConfig {
232 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
237 let raw: serde_json::Value = serde_json::from_str(json)?;
238 Self::from_value(raw)
239 }
240
241 pub fn from_value(value: serde_json::Value) -> Result<Self, serde_json::Error> {
246 let mut config: Self = serde_json::from_value(value.clone())?;
247 Self::apply_nested_media_cfg(&mut config, &value);
248 Ok(config)
249 }
250
251 fn apply_nested_media_cfg(config: &mut Self, raw: &serde_json::Value) {
254 let Some(media_cfg) = raw.get("media_proc_cfg") else {
255 return;
256 };
257 if config.image_mean.is_none() {
258 config.image_mean = media_cfg
259 .get("image_mean")
260 .and_then(|v| serde_json::from_value(v.clone()).ok());
261 }
262 if config.image_std.is_none() {
263 config.image_std = media_cfg
264 .get("image_std")
265 .and_then(|v| serde_json::from_value(v.clone()).ok());
266 }
267 if config.patch_size.is_none() {
268 config.patch_size = media_cfg.get("patch_size").and_then(|v| {
269 v.as_u64().map(|ps| PatchSize {
270 height: Some(ps as u32),
271 width: Some(ps as u32),
272 })
273 });
274 }
275 if config.merge_size.is_none() {
276 config.merge_size = media_cfg
277 .get("merge_kernel_size")
278 .and_then(|v| v.as_u64())
279 .map(|v| v as usize);
280 }
281 for key in ["in_patch_limit", "patch_limit_on_one_side"] {
284 if !config.extra.contains_key(key) {
285 if let Some(v) = media_cfg.get(key) {
286 config.extra.insert(key.to_string(), v.clone());
287 }
288 }
289 }
290 }
291
292 pub fn get_patch_size(&self, default: usize) -> usize {
296 self.patch_size
297 .as_ref()
298 .and_then(|p| p.height)
299 .map(|h| h as usize)
300 .unwrap_or(default)
301 }
302
303 pub fn get_image_mean(&self) -> [f64; 3] {
305 self.image_mean
306 .as_ref()
307 .and_then(|v| {
308 if v.len() >= 3 {
309 Some([v[0], v[1], v[2]])
310 } else {
311 None
312 }
313 })
314 .unwrap_or(Self::CLIP_MEAN)
315 }
316
317 pub fn get_image_std(&self) -> [f64; 3] {
319 self.image_std
320 .as_ref()
321 .and_then(|v| {
322 if v.len() >= 3 {
323 Some([v[0], v[1], v[2]])
324 } else {
325 None
326 }
327 })
328 .unwrap_or(Self::CLIP_STD)
329 }
330
331 pub fn get_target_size(&self) -> Option<(u32, u32)> {
336 self.size.as_ref().map(|s| {
337 let h = s
339 .get("height")
340 .or_else(|| s.get("shortest_edge"))
341 .copied()
342 .unwrap_or(224);
343 let w = s
344 .get("width")
345 .or_else(|| s.get("shortest_edge"))
346 .copied()
347 .unwrap_or(224);
348 (h, w)
349 })
350 }
351
352 pub fn get_crop_size(&self) -> Option<(u32, u32)> {
356 self.crop_size.as_ref().map(|s| {
357 let h = s.get("height").copied().unwrap_or(224);
358 let w = s.get("width").copied().unwrap_or(224);
359 (h, w)
360 })
361 }
362
363 pub fn get_filter(&self) -> FilterType {
365 transforms::pil_to_filter(self.resampling)
366 }
367
368 pub fn should_normalize(&self) -> bool {
370 self.do_normalize.unwrap_or(true)
371 }
372
373 pub fn should_rescale(&self) -> bool {
375 self.do_rescale.unwrap_or(false)
376 }
377
378 pub fn should_resize(&self) -> bool {
380 self.do_resize.unwrap_or(true)
381 }
382
383 pub fn should_center_crop(&self) -> bool {
385 self.do_center_crop.unwrap_or(false)
386 }
387
388 pub fn get_rescale_factor(&self) -> f64 {
390 self.rescale_factor.unwrap_or(1.0 / 255.0)
391 }
392
393 pub fn get_extra<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
395 self.extra
396 .get(key)
397 .and_then(|v| serde_json::from_value(v.clone()).ok())
398 }
399
400 pub const CLIP_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
402 pub const CLIP_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
403
404 pub const IMAGENET_MEAN: [f64; 3] = [0.485, 0.456, 0.406];
405 pub const IMAGENET_STD: [f64; 3] = [0.229, 0.224, 0.225];
406
407 pub const SIGLIP_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
408 pub const SIGLIP_STD: [f64; 3] = [0.5, 0.5, 0.5];
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let diff = (actual - expected).abs();
        assert!(diff < tol, "expected {expected} (±{tol}), got {actual}");
    }

    #[test]
    fn test_parse_clip_config() {
        let json = r#"{
            "do_center_crop": true,
            "do_normalize": true,
            "do_resize": true,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "resample": 3,
            "size": {"shortest_edge": 224}
        }"#;
        let config = PreProcessorConfig::from_json(json).expect("CLIP config should parse");

        assert!(config.should_normalize());
        assert!(config.should_center_crop());
        assert!(config.should_resize());
        assert_eq!(config.resampling, Some(3));
        assert_eq!(config.get_target_size(), Some((224, 224)));
        assert_close(config.get_image_mean()[0], 0.48145466, 1e-6);
    }

    #[test]
    fn test_parse_qwen_vl_config() {
        let json = r#"{
            "do_normalize": true,
            "do_rescale": true,
            "do_resize": true,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "min_pixels": 200704,
            "max_pixels": 1003520,
            "patch_size": 14,
            "merge_size": 2,
            "temporal_patch_size": 2,
            "rescale_factor": 0.00392156862745098
        }"#;
        let config = PreProcessorConfig::from_json(json).expect("Qwen-VL config should parse");

        assert_eq!(config.min_pixels, Some(200704));
        assert_eq!(config.max_pixels, Some(1003520));
        assert_eq!(config.get_patch_size(0), 14);
        assert_eq!(config.merge_size, Some(2));
        assert_close(config.get_rescale_factor(), 1.0 / 255.0, 1e-10);
    }

    #[test]
    fn test_parse_size_formats() {
        // Explicit height/width and shortest_edge must both resolve to a square target.
        let cases = [
            (r#"{"size": {"height": 336, "width": 336}}"#, (336, 336)),
            (r#"{"size": {"shortest_edge": 224}}"#, (224, 224)),
        ];
        for (json, expected) in cases {
            let config = PreProcessorConfig::from_json(json).unwrap();
            assert_eq!(config.get_target_size(), Some(expected), "input: {json}");
        }
    }

    #[test]
    fn test_defaults() {
        let config = PreProcessorConfig::default();

        // With nothing configured, the mean falls back to CLIP's.
        assert_close(
            config.get_image_mean()[0],
            PreProcessorConfig::CLIP_MEAN[0],
            1e-6,
        );

        assert!(config.should_normalize());
        assert!(!config.should_rescale());
        assert!(config.should_resize());
        assert!(!config.should_center_crop());
    }

    #[test]
    fn test_filter_conversion() {
        let config = PreProcessorConfig::from_json(r#"{"resampling": 3}"#).unwrap();
        assert!(matches!(config.get_filter(), FilterType::CatmullRom));
    }

    #[test]
    fn test_extra_fields() {
        let json = r#"{
            "custom_field": 42,
            "nested": {"foo": "bar"}
        }"#;
        let config = PreProcessorConfig::from_json(json).unwrap();

        assert_eq!(config.get_extra::<i32>("custom_field"), Some(42));

        let nested = config
            .get_extra::<HashMap<String, String>>("nested")
            .expect("nested object should deserialize");
        assert_eq!(nested.get("foo").map(String::as_str), Some("bar"));
    }

    #[test]
    fn test_parse_kimi_nested_media_proc_cfg() {
        let json = r#"{
            "auto_map": {
                "AutoProcessor": "kimi_k25_processor.KimiK25Processor"
            },
            "media_proc_cfg": {
                "in_patch_limit": 16384,
                "patch_size": 14,
                "image_mean": [0.5, 0.5, 0.5],
                "image_std": [0.5, 0.5, 0.5],
                "merge_kernel_size": 2,
                "patch_limit_on_one_side": 512
            }
        }"#;
        let config = PreProcessorConfig::from_json(json).unwrap();

        // Values only present under media_proc_cfg must be picked up.
        for channel in config.get_image_mean() {
            assert_close(channel, 0.5, 1e-6);
        }
        assert_close(config.get_image_std()[0], 0.5, 1e-6);

        assert_eq!(config.get_patch_size(0), 14);
        assert_eq!(config.merge_size, Some(2));
    }
}