Skip to main content

ferrum_models/
image_processor.rs

1//! Image preprocessing for CLIP models.
2//!
3//! Load → resize → normalize → [1, 3, H, W] tensor.
4
5use candle_core::{DType, Device, Tensor};
6use ferrum_types::{FerrumError, Result};
7
/// CLIP image processor with configurable output size and per-channel normalization.
///
/// Turns an image into a `[1, 3, image_size, image_size]` f32 tensor whose channels are
/// normalized as `(x / 255 - mean[c]) / std[c]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ClipImageProcessor {
    /// Side length, in pixels, of the square image produced for the model.
    image_size: usize,
    /// Per-channel (R, G, B) mean subtracted after scaling pixels to `[0, 1]`.
    mean: [f32; 3],
    /// Per-channel (R, G, B) standard deviation divided out after mean subtraction.
    std: [f32; 3],
}
14
15impl ClipImageProcessor {
16    /// Standard CLIP normalization (ImageNet stats).
17    pub fn new(image_size: usize) -> Self {
18        Self {
19            image_size,
20            mean: [0.48145466, 0.4578275, 0.40821073],
21            std: [0.26862954, 0.26130258, 0.27577711],
22        }
23    }
24
25    /// Load image from file path → preprocessed tensor.
26    pub fn process_path(&self, path: &str, device: &Device) -> Result<Tensor> {
27        let img =
28            image::open(path).map_err(|e| FerrumError::model(format!("image load {path}: {e}")))?;
29        self.process_image(img, device)
30    }
31
32    /// Decode base64 image data → preprocessed tensor.
33    pub fn process_base64(&self, data: &str, device: &Device) -> Result<Tensor> {
34        use base64::Engine;
35        // Strip optional data URI prefix
36        let raw = if let Some(pos) = data.find(",") {
37            &data[pos + 1..]
38        } else {
39            data
40        };
41        let bytes = base64::engine::general_purpose::STANDARD
42            .decode(raw)
43            .map_err(|e| FerrumError::model(format!("base64 decode: {e}")))?;
44        let img = image::load_from_memory(&bytes)
45            .map_err(|e| FerrumError::model(format!("image decode: {e}")))?;
46        self.process_image(img, device)
47    }
48
49    /// Core pipeline: DynamicImage → resize → normalize → [1, 3, H, W] tensor.
50    fn process_image(&self, img: image::DynamicImage, device: &Device) -> Result<Tensor> {
51        let size = self.image_size as u32;
52        let img = img.resize_exact(size, size, image::imageops::FilterType::Triangle);
53        let img = img.to_rgb8();
54
55        let (w, h) = (img.width() as usize, img.height() as usize);
56        let raw: Vec<f32> = img
57            .into_raw()
58            .into_iter()
59            .map(|p| p as f32 / 255.0)
60            .collect();
61
62        // [H, W, 3] → [3, H, W] with normalization
63        let mut chw = vec![0f32; 3 * h * w];
64        for c in 0..3 {
65            for i in 0..h * w {
66                chw[c * h * w + i] = (raw[i * 3 + c] - self.mean[c]) / self.std[c];
67            }
68        }
69
70        Tensor::from_vec(chw, (1, 3, h, w), device)
71            .map_err(|e| FerrumError::model(format!("tensor: {e}")))?
72            .to_dtype(DType::F32)
73            .map_err(|e| FerrumError::model(format!("dtype: {e}")))
74    }
75}