Skip to main content

ocrs/
preprocess.rs

1use std::fmt::Debug;
2
3use rten_tensor::prelude::*;
4use rten_tensor::{NdTensor, NdTensorView};
5use thiserror::Error;
6
/// View of an image's pixels, in either (height, width, channels) or (channels,
/// height, width) order.
///
/// The view itself does not record which of the two orders it uses; that is
/// tracked alongside it by [ImageSource] via a [DimOrder].
pub enum ImagePixels<'a> {
    /// Pixel values in the range [0, 1]
    Floats(NdTensorView<'a, f32, 3>),
    /// Pixel values in the range [0, 255]
    Bytes(NdTensorView<'a, u8, 3>),
}
15
16impl<'a> From<NdTensorView<'a, f32, 3>> for ImagePixels<'a> {
17    fn from(value: NdTensorView<'a, f32, 3>) -> Self {
18        ImagePixels::Floats(value)
19    }
20}
21
22impl<'a> From<NdTensorView<'a, u8, 3>> for ImagePixels<'a> {
23    fn from(value: NdTensorView<'a, u8, 3>) -> Self {
24        ImagePixels::Bytes(value)
25    }
26}
27
28impl ImagePixels<'_> {
29    fn shape(&self) -> [usize; 3] {
30        match self {
31            ImagePixels::Floats(f) => f.shape(),
32            ImagePixels::Bytes(b) => b.shape(),
33        }
34    }
35}
36
/// Errors that can occur when creating an [ImageSource].
#[derive(Error, Clone, Debug, PartialEq)]
pub enum ImageSourceError {
    /// The image channel count is not 1 (greyscale), 3 (RGB) or 4 (RGBA).
    ///
    /// [ImageSource::from_bytes] also reports this error when the width or
    /// height is zero, since no channel count can be inferred in that case.
    #[error("channel count is not 1, 3 or 4")]
    UnsupportedChannelCount,
    /// The image data length is not a multiple of `width * height`, i.e. the
    /// number of elements in a single channel.
    #[error("data length is not a multiple of `width * height`")]
    InvalidDataLength,
}
47
/// Specifies the order in which pixels are laid out in an image tensor.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum DimOrder {
    /// Channels last order, i.e. a (height, width, channels) tensor. This is
    /// the order used by the [image](https://github.com/image-rs/image) crate
    /// and HTML Canvas APIs.
    Hwc,
    /// Channels first order, i.e. a (channels, height, width) tensor. This is
    /// the order used by many machine-learning libraries for image tensors.
    Chw,
}
58
/// View of an image, for use with
/// [OcrEngine::prepare_input](crate::OcrEngine::prepare_input).
///
/// Construct one with [ImageSource::from_bytes] or [ImageSource::from_tensor],
/// both of which check that the image has 1, 3 or 4 channels.
pub struct ImageSource<'a> {
    /// Borrowed pixel data.
    data: ImagePixels<'a>,
    /// Which dimension of `data` holds the channels.
    order: DimOrder,
}
65
66impl<'a> ImageSource<'a> {
67    /// Create an image source from a buffer of pixels in HWC order.
68    ///
69    /// An image loaded using the `image` crate can be converted to an
70    /// [ImageSource] using:
71    ///
72    /// ```no_run
73    /// use ocrs::ImageSource;
74    ///
75    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
76    /// let image = image::open("image.jpg")?.into_rgb8();
77    /// let img_source = ImageSource::from_bytes(image.as_raw(), image.dimensions())?;
78    /// # Ok(())
79    /// # }
80    /// ```
81    pub fn from_bytes(
82        bytes: &'a [u8],
83        dimensions: (u32, u32),
84    ) -> Result<ImageSource<'a>, ImageSourceError> {
85        let (width, height) = dimensions;
86        let channel_len = (width * height) as usize;
87
88        if channel_len == 0 {
89            return Err(ImageSourceError::UnsupportedChannelCount);
90        }
91
92        if !bytes.len().is_multiple_of(channel_len) {
93            return Err(ImageSourceError::InvalidDataLength);
94        }
95
96        let channels = bytes.len() / channel_len;
97        Self::from_tensor(
98            NdTensorView::from_data([height as usize, width as usize, channels], bytes),
99            DimOrder::Hwc,
100        )
101    }
102
103    /// Create an image source from a tensor of bytes (`u8`) or floats (`f32`),
104    /// in either channels-first (CHW) or channels-last (HWC) order.
105    pub fn from_tensor<T>(
106        data: NdTensorView<'a, T, 3>,
107        order: DimOrder,
108    ) -> Result<ImageSource<'a>, ImageSourceError>
109    where
110        NdTensorView<'a, T, 3>: Into<ImagePixels<'a>>,
111    {
112        let channels = match order {
113            DimOrder::Hwc => data.size(2),
114            DimOrder::Chw => data.size(0),
115        };
116        match channels {
117            1 | 3 | 4 => Ok(ImageSource {
118                data: data.into(),
119                order,
120            }),
121            _ => Err(ImageSourceError::UnsupportedChannelCount),
122        }
123    }
124}
125
/// The value used to represent fully black pixels in OCR input images
/// prepared by [prepare_image].
///
/// [prepare_image] adds this bias to each greyscale intensity in `[0, 1]`, so
/// a fully white pixel maps to approximately `BLACK_VALUE + 1.0` (`0.5`).
pub const BLACK_VALUE: f32 = -0.5;
129
/// Specifies the number and order of color channels in an image.
///
/// Used internally to select the per-channel weights for greyscale
/// conversion.
enum Channels {
    /// One channel: greyscale.
    Grey,
    /// Three channels: red, green, blue.
    Rgb,
    /// Four channels: red, green, blue, alpha. Alpha is not weighted during
    /// greyscale conversion.
    Rgba,
}
136
137/// Prepare an image for use with text detection and recognition models.
138///
139/// This involves:
140///
141/// - Converting the pixels to floats
142/// - Converting the color format to greyscale
143/// - Adding a bias ([BLACK_VALUE]) to the greyscale value
144///
145/// The greyscale conversion is intended to approximately match torchvision's
146/// RGB => greyscale conversion when using `torchvision.io.read_image(path,
147/// ImageReadMode.GRAY)`, which is used when training models with greyscale
148/// inputs. torchvision internally uses libpng's `png_set_rgb_to_gray`.
149pub fn prepare_image(img: ImageSource) -> NdTensor<f32, 3> {
150    match img.order {
151        DimOrder::Hwc => prepare_image_impl::<true>(img.data),
152        DimOrder::Chw => prepare_image_impl::<false>(img.data),
153    }
154}
155
156fn prepare_image_impl<const CHANS_LAST: bool>(pixels: ImagePixels) -> NdTensor<f32, 3> {
157    let n_chans = if CHANS_LAST {
158        pixels.shape()[2]
159    } else {
160        pixels.shape()[0]
161    };
162    let src_chans = match n_chans {
163        1 => Channels::Grey,
164        3 => Channels::Rgb,
165        4 => Channels::Rgba,
166        _ => panic!("expected greyscale, RGB or RGBA input image"),
167    };
168
169    // ITU BT.601 weights for RGB => luminance conversion. These match what
170    // torchvision uses. See also https://stackoverflow.com/a/596241/434243.
171    const ITU_WEIGHTS: [f32; 3] = [0.299, 0.587, 0.114];
172
173    match pixels {
174        ImagePixels::Floats(floats) => match src_chans {
175            Channels::Grey => convert_pixels::<_, 1, _, CHANS_LAST>(floats.view(), [1.]),
176            Channels::Rgb => convert_pixels::<_, 3, _, CHANS_LAST>(floats.view(), ITU_WEIGHTS),
177            Channels::Rgba => convert_pixels::<_, 4, _, CHANS_LAST>(floats.view(), ITU_WEIGHTS),
178        },
179        ImagePixels::Bytes(bytes) => {
180            // Combine the byte -> float scaling and color components into
181            // a single weight.
182            let weights = ITU_WEIGHTS.map(|w| w / 255.0);
183            match src_chans {
184                Channels::Grey => convert_pixels::<_, 1, _, CHANS_LAST>(bytes.view(), [1. / 255.]),
185                Channels::Rgb => convert_pixels::<_, 3, _, CHANS_LAST>(bytes.view(), weights),
186                Channels::Rgba => convert_pixels::<_, 4, _, CHANS_LAST>(bytes.view(), weights),
187            }
188        }
189    }
190}
191
192/// Convert pixels in an image to floats and scale by the given channel weights.
193///
194/// `PIXEL_STRIDE` is the number of elements per pixel in the input: 1 for grey,
195/// 3 for RGB or 4 for RGBA. `CHANS` is the number of color channels used from
196/// the input (1 for grey, 3 for RGB or RGBA). `CHANS_LAST` specifies the
197/// input has (height, width, chans) if true, or (chans, height, width)
198/// if false.
199///
200/// Returns a (1, H, W) tensor.
201fn convert_pixels<
202    T: AsF32,
203    const PIXEL_STRIDE: usize,
204    const CHANS: usize,
205    const CHANS_LAST: bool,
206>(
207    src: NdTensorView<T, 3>,
208    chan_weights: [f32; CHANS],
209) -> NdTensor<f32, 3> {
210    let [height, width, chans] = if CHANS_LAST {
211        src.shape()
212    } else {
213        let [c, h, w] = src.shape();
214        [h, w, c]
215    };
216    assert_eq!(chans, PIXEL_STRIDE);
217    let mut out_pixels = Vec::with_capacity(height * width);
218
219    if CHANS_LAST {
220        // For channels-last input, we can load the input in contiguous
221        // autovectorization-friendly chunks.
222
223        // We assume the input is likely contiguous, so this should be cheap.
224        let src = src.to_contiguous();
225        let (src_pixels, remainder) = src.data().as_chunks::<PIXEL_STRIDE>();
226        debug_assert!(remainder.is_empty());
227
228        out_pixels.extend(src_pixels.iter().map(|in_pixel| {
229            let mut pixel = BLACK_VALUE;
230            for c in 0..chan_weights.len() {
231                pixel += in_pixel[c].as_f32() * chan_weights[c]
232            }
233            pixel
234        }));
235    } else {
236        for y in 0..height {
237            out_pixels.extend((0..width).map(|x| {
238                let mut pixel = BLACK_VALUE;
239                for c in 0..chan_weights.len() {
240                    pixel += src[[c, y, x]].as_f32() * chan_weights[c]
241                }
242                pixel
243            }));
244        }
245    }
246
247    NdTensor::from_data([1, height, width], out_pixels)
248}
249
/// Convert a primitive pixel component to an `f32`.
///
/// The conversion is exact for the implemented types: `f32` is returned as-is
/// and every `u8` value is representable in `f32`.
trait AsF32: Copy {
    fn as_f32(self) -> f32;
}

impl AsF32 for f32 {
    /// Identity conversion.
    fn as_f32(self) -> f32 {
        self
    }
}

impl AsF32 for u8 {
    /// Lossless widening conversion (same result as `self as f32`).
    fn as_f32(self) -> f32 {
        f32::from(self)
    }
}
266
// Unit tests for `ImageSource` validation and `prepare_image` greyscale
// conversion.
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    use super::{prepare_image, DimOrder, ImageSource, ImageSourceError, BLACK_VALUE};

    /// `ImageSource::from_bytes` infers the channel count from the buffer
    /// length and rejects bad lengths, unsupported channel counts and
    /// zero-sized dimensions.
    #[test]
    fn test_image_source_from_bytes() {
        struct Case {
            len: usize,
            width: u32,
            height: u32,
            error: Option<ImageSourceError>,
        }

        let cases = [
            // 100 bytes over 10x10 pixels => 1 channel (greyscale).
            Case {
                len: 100,
                width: 10,
                height: 10,
                error: None,
            },
            // 50 is not a multiple of 10 * 10.
            Case {
                len: 50,
                width: 10,
                height: 10,
                error: Some(ImageSourceError::InvalidDataLength),
            },
            // Length implies 2 channels, which is not supported.
            Case {
                len: 8 * 8 * 2,
                width: 8,
                height: 8,
                error: Some(ImageSourceError::UnsupportedChannelCount),
            },
            // A zero-sized dimension is reported as an unsupported channel count.
            Case {
                len: 0,
                width: 0,
                height: 10,
                error: Some(ImageSourceError::UnsupportedChannelCount),
            },
        ];

        for Case {
            len,
            width,
            height,
            error,
        } in cases
        {
            // Only the buffer length matters here, not the byte values. Note
            // `len as u8` assumes every case's `len` is at most 255.
            let data: Vec<u8> = (0u8..len as u8).collect();
            let source = ImageSource::from_bytes(&data, (width, height));
            assert_eq!(source.as_ref().err(), error.as_ref());
        }
    }

    /// `ImageSource::from_tensor` reads the channel count from the dimension
    /// selected by the `DimOrder` and rejects counts other than 1, 3 or 4.
    #[test]
    fn test_image_source_from_data() {
        struct Case {
            shape: [usize; 3],
            error: Option<ImageSourceError>,
            order: DimOrder,
        }

        let cases = [
            // CHW: first dimension => 1 channel.
            Case {
                shape: [1, 5, 5],
                error: None,
                order: DimOrder::Chw,
            },
            // HWC: last dimension => 5 channels, rejected.
            Case {
                shape: [1, 5, 5],
                error: Some(ImageSourceError::UnsupportedChannelCount),
                order: DimOrder::Hwc,
            },
            // CHW with 0 channels, rejected.
            Case {
                shape: [0, 5, 5],
                error: Some(ImageSourceError::UnsupportedChannelCount),
                order: DimOrder::Chw,
            },
        ];

        for Case {
            shape,
            error,
            order,
        } in cases
        {
            let len: usize = shape.iter().product();
            let tensor = NdTensor::<u8, 1>::arange(0, len as u8, None).into_shape(shape);
            let source = ImageSource::from_tensor(tensor.view(), order);
            assert_eq!(source.as_ref().err(), error.as_ref());
        }
    }

    /// ITU BT.601 weights for RGB => luminance conversion.
    ///
    /// Duplicated from the implementation so the tests check against
    /// independently written expected values.
    const ITU_WEIGHTS: [f32; 3] = [0.299, 0.587, 0.114];

    /// Helper to compute expected greyscale value from RGB.
    ///
    /// Inputs are channel intensities in `[0, 1]`; the result includes the
    /// `BLACK_VALUE` bias.
    fn expected_grey_from_rgb(r: f32, g: f32, b: f32) -> f32 {
        BLACK_VALUE + r * ITU_WEIGHTS[0] + g * ITU_WEIGHTS[1] + b * ITU_WEIGHTS[2]
    }

    /// Assert two floats are equal within an absolute tolerance of `1e-5`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {expected}, got {actual}"
        );
    }

    /// Single-channel `u8` input is scaled to `[0, 1]` and biased, for both
    /// HWC and CHW layouts.
    #[test]
    fn test_prepare_image_greyscale_u8() {
        struct Case {
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            Case {
                shape: [2, 2, 1],
                order: DimOrder::Hwc,
            },
            Case {
                shape: [1, 2, 2],
                order: DimOrder::Chw,
            },
        ];

        for Case { shape, order } in cases {
            // With one channel, HWC and CHW buffers have the same element order.
            let data: Vec<u8> = vec![0, 128, 255, 64];
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 2, 2]);
            assert_close(result[[0, 0, 0]], BLACK_VALUE + 0.0);
            assert_close(result[[0, 0, 1]], BLACK_VALUE + 128.0 / 255.0);
            assert_close(result[[0, 1, 0]], BLACK_VALUE + 1.0);
            assert_close(result[[0, 1, 1]], BLACK_VALUE + 64.0 / 255.0);
        }
    }

    /// Single-channel `f32` input is passed through unscaled, plus the bias.
    #[test]
    fn test_prepare_image_greyscale_f32() {
        struct Case {
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            Case {
                shape: [2, 2, 1],
                order: DimOrder::Hwc,
            },
            Case {
                shape: [1, 2, 2],
                order: DimOrder::Chw,
            },
        ];

        for Case { shape, order } in cases {
            let data: Vec<f32> = vec![0.0, 0.5, 1.0, 0.25];
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 2, 2]);
            assert_close(result[[0, 0, 0]], BLACK_VALUE + 0.0);
            assert_close(result[[0, 0, 1]], BLACK_VALUE + 0.5);
            assert_close(result[[0, 1, 0]], BLACK_VALUE + 1.0);
            assert_close(result[[0, 1, 1]], BLACK_VALUE + 0.25);
        }
    }

    /// RGB and RGBA `u8` pixels are converted with the BT.601 weights in both
    /// layouts, ignoring alpha.
    #[test]
    fn test_prepare_image_rgb_rgba_u8() {
        struct Case {
            data: Vec<u8>,
            shape: [usize; 3],
            order: DimOrder,
            rgb: [u8; 3],
        }

        let cases = [
            // RGB HWC
            Case {
                data: vec![100, 150, 200],
                shape: [1, 1, 3],
                order: DimOrder::Hwc,
                rgb: [100, 150, 200],
            },
            // RGB CHW
            Case {
                data: vec![100, 150, 200],
                shape: [3, 1, 1],
                order: DimOrder::Chw,
                rgb: [100, 150, 200],
            },
            // RGBA HWC (alpha should be ignored)
            Case {
                data: vec![50, 100, 150, 255],
                shape: [1, 1, 4],
                order: DimOrder::Hwc,
                rgb: [50, 100, 150],
            },
            // RGBA CHW (alpha should be ignored)
            Case {
                data: vec![50, 100, 150, 255],
                shape: [4, 1, 1],
                order: DimOrder::Chw,
                rgb: [50, 100, 150],
            },
        ];

        for Case {
            data,
            shape,
            order,
            rgb: [r, g, b],
        } in cases
        {
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 1, 1]);
            let expected =
                expected_grey_from_rgb(r as f32 / 255.0, g as f32 / 255.0, b as f32 / 255.0);
            assert_close(result[[0, 0, 0]], expected);
        }
    }

    /// RGB `f32` pixels are converted with the BT.601 weights in both layouts.
    #[test]
    fn test_prepare_image_rgb_f32() {
        struct Case {
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            Case {
                shape: [1, 1, 3],
                order: DimOrder::Hwc,
            },
            Case {
                shape: [3, 1, 1],
                order: DimOrder::Chw,
            },
        ];

        let (r, g, b) = (0.4, 0.6, 0.8);

        for Case { shape, order } in cases {
            let data: Vec<f32> = vec![r, g, b];
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 1, 1]);
            let expected = expected_grey_from_rgb(r, g, b);
            assert_close(result[[0, 0, 0]], expected);
        }
    }

    /// The same 2x2 image stored as HWC and as CHW must produce the same
    /// output, pixel for pixel.
    #[test]
    fn test_prepare_image_multi_pixel_rgb() {
        // Test both HWC and CHW with a 2x2 image to verify iteration order
        struct Case {
            data: Vec<u8>,
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            // HWC layout
            Case {
                #[rustfmt::skip]
                data: vec![
                    255, 0, 0,    // (0,0) red
                    0, 255, 0,    // (0,1) green
                    0, 0, 255,    // (1,0) blue
                    128, 128, 128 // (1,1) grey
                ],
                shape: [2, 2, 3],
                order: DimOrder::Hwc,
            },
            // CHW layout (same image, different memory layout)
            Case {
                #[rustfmt::skip]
                data: vec![
                    // R channel
                    255, 0,
                    0, 128,
                    // G channel
                    0, 255,
                    0, 128,
                    // B channel
                    0, 0,
                    255, 128,
                ],
                shape: [3, 2, 2],
                order: DimOrder::Chw,
            },
        ];

        let expected_red = expected_grey_from_rgb(1.0, 0.0, 0.0);
        let expected_green = expected_grey_from_rgb(0.0, 1.0, 0.0);
        let expected_blue = expected_grey_from_rgb(0.0, 0.0, 1.0);
        let expected_grey = expected_grey_from_rgb(128.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0);

        for Case { data, shape, order } in cases {
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            // Output is (1, height, width) regardless of the input layout.
            assert_eq!(result.shape(), [1, 2, 2]);
            assert_close(result[[0, 0, 0]], expected_red);
            assert_close(result[[0, 0, 1]], expected_green);
            assert_close(result[[0, 1, 0]], expected_blue);
            assert_close(result[[0, 1, 1]], expected_grey);
        }
    }
}