Skip to main content

rlx_ocr/
preprocess.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Greyscale image preprocessing for OCR models.
17
18use rten_tensor::prelude::*;
19use rten_tensor::{NdTensor, NdTensorView};
20use std::fmt;
21
/// Normalized greyscale background value used by ocrs models (matches `ocrs` 0.12.x).
///
/// Pixel conversion in this module starts every output pixel at this offset and then adds
/// the weighted channel intensities, so an all-zero input pixel maps to exactly this value.
pub const BLACK_VALUE: f32 = -0.5;
24
/// Errors when constructing an [`ImageSource`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ImageSourceError {
    /// The channel count is not 1 (greyscale), 3 (RGB) or 4 (RGBA).
    ///
    /// `ImageSource::from_bytes` also reports this when `width * height` is zero.
    UnsupportedChannelCount,
    /// The byte buffer length is not a multiple of `width * height`, so no whole number of
    /// channels per pixel fits the data.
    InvalidDataLength,
}
31
32impl fmt::Display for ImageSourceError {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            Self::UnsupportedChannelCount => f.write_str("channel count is not 1, 3 or 4"),
36            Self::InvalidDataLength => {
37                f.write_str("data length is not a multiple of width * height")
38            }
39        }
40    }
41}
42
43impl std::error::Error for ImageSourceError {}
44
/// Pixel layout for image tensors.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DimOrder {
    /// Channels last: the tensor is indexed as `[height, width, channels]`.
    Hwc,
    /// Channels first: the tensor is indexed as `[channels, height, width]`.
    Chw,
}
51
52enum ImagePixels<'a> {
53    #[allow(dead_code)]
54    Floats(NdTensorView<'a, f32, 3>),
55    Bytes(NdTensorView<'a, u8, 3>),
56    FloatsOwned(NdTensor<f32, 3>),
57}
58
59/// Input image for [`crate::OcrEngine::prepare_input`].
60pub struct ImageSource<'a> {
61    data: ImagePixels<'a>,
62    order: DimOrder,
63}
64
65impl<'a> ImageSource<'a> {
66    /// RGB/RGBA/greyscale bytes in HWC order (`image` crate layout).
67    pub fn from_bytes(bytes: &'a [u8], dimensions: (u32, u32)) -> Result<Self, ImageSourceError> {
68        let (width, height) = dimensions;
69        let channel_len = (width as usize).saturating_mul(height as usize);
70        if channel_len == 0 {
71            return Err(ImageSourceError::UnsupportedChannelCount);
72        }
73        if !bytes.len().is_multiple_of(channel_len) {
74            return Err(ImageSourceError::InvalidDataLength);
75        }
76        let chans = bytes.len() / channel_len;
77        if !matches!(chans, 1 | 3 | 4) {
78            return Err(ImageSourceError::UnsupportedChannelCount);
79        }
80        let view = NdTensorView::from_data([height as usize, width as usize, chans], bytes);
81        Ok(Self {
82            data: ImagePixels::Bytes(view),
83            order: DimOrder::Hwc,
84        })
85    }
86
87    /// Existing CHW or HWC float tensor (copied).
88    pub fn from_tensor(
89        tensor: NdTensorView<'_, f32, 3>,
90        order: DimOrder,
91    ) -> Result<ImageSource<'static>, ImageSourceError> {
92        let chans = match order {
93            DimOrder::Hwc => tensor.size(2),
94            DimOrder::Chw => tensor.size(0),
95        };
96        if chans == 0 || !matches!(chans, 1 | 3 | 4) {
97            return Err(ImageSourceError::UnsupportedChannelCount);
98        }
99        let owned = NdTensor::from_data(tensor.shape(), tensor.to_vec());
100        Ok(ImageSource {
101            data: ImagePixels::FloatsOwned(owned),
102            order,
103        })
104    }
105}
106
107/// Convert an image to a normalized greyscale CHW tensor `[1, H, W]`.
108pub fn prepare_image(img: ImageSource<'_>) -> NdTensor<f32, 3> {
109    match (&img.data, img.order) {
110        (ImagePixels::Floats(f), DimOrder::Hwc) => prepare_floats::<true>(f.view()),
111        (ImagePixels::Floats(f), DimOrder::Chw) => prepare_floats::<false>(f.view()),
112        (ImagePixels::FloatsOwned(f), DimOrder::Hwc) => prepare_floats::<true>(f.view()),
113        (ImagePixels::FloatsOwned(f), DimOrder::Chw) => prepare_floats::<false>(f.view()),
114        (ImagePixels::Bytes(b), DimOrder::Hwc) => prepare_bytes::<true>(b.view()),
115        (ImagePixels::Bytes(b), DimOrder::Chw) => prepare_bytes::<false>(b.view()),
116    }
117}
118
119fn prepare_floats<const CHANS_LAST: bool>(floats: NdTensorView<'_, f32, 3>) -> NdTensor<f32, 3> {
120    const ITU: [f32; 3] = [0.299, 0.587, 0.114];
121    let n = if CHANS_LAST {
122        floats.shape()[2]
123    } else {
124        floats.shape()[0]
125    };
126    match n {
127        1 => convert_pixels::<f32, 1, 1, CHANS_LAST>(floats, [1.]),
128        3 => convert_pixels::<f32, 3, 3, CHANS_LAST>(floats, ITU),
129        4 => convert_pixels::<f32, 4, 3, CHANS_LAST>(floats, ITU),
130        _ => panic!("expected greyscale, RGB or RGBA input image"),
131    }
132}
133
134fn prepare_bytes<const CHANS_LAST: bool>(bytes: NdTensorView<'_, u8, 3>) -> NdTensor<f32, 3> {
135    const ITU: [f32; 3] = [0.299, 0.587, 0.114];
136    let weights = ITU.map(|w| w / 255.0);
137    let n = if CHANS_LAST {
138        bytes.shape()[2]
139    } else {
140        bytes.shape()[0]
141    };
142    match n {
143        1 => convert_pixels::<u8, 1, 1, CHANS_LAST>(bytes, [1. / 255.0]),
144        3 => convert_pixels::<u8, 3, 3, CHANS_LAST>(bytes, weights),
145        4 => convert_pixels::<u8, 4, 3, CHANS_LAST>(bytes, weights),
146        _ => panic!("expected greyscale, RGB or RGBA input image"),
147    }
148}
149
150fn convert_pixels<
151    T: Copy + Into<f32>,
152    const PIXEL_STRIDE: usize,
153    const CHANS: usize,
154    const CHANS_LAST: bool,
155>(
156    src: NdTensorView<'_, T, 3>,
157    chan_weights: [f32; CHANS],
158) -> NdTensor<f32, 3> {
159    let [height, width, chans] = if CHANS_LAST {
160        src.shape()
161    } else {
162        let [c, h, w] = src.shape();
163        [h, w, c]
164    };
165    assert_eq!(chans, PIXEL_STRIDE);
166
167    let mut out_pixels = Vec::with_capacity(height * width);
168    if CHANS_LAST {
169        let src = src.to_contiguous();
170        let mut iter = src.data().chunks_exact(PIXEL_STRIDE);
171        debug_assert!(iter.remainder().is_empty());
172        for in_pixel in iter.by_ref() {
173            let mut pixel = BLACK_VALUE;
174            for (c, &w) in chan_weights.iter().enumerate() {
175                pixel += in_pixel[c].into() * w;
176            }
177            out_pixels.push(pixel);
178        }
179    } else {
180        for y in 0..height {
181            out_pixels.extend((0..width).map(|x| {
182                let mut pixel = BLACK_VALUE;
183                for (c, &w) in chan_weights.iter().enumerate() {
184                    pixel += src[[c, y, x]].into() * w;
185                }
186                pixel
187            }));
188        }
189    }
190    NdTensor::from_data([1, height, width], out_pixels)
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing normalized pixel values.
    const EPS: f32 = 1e-5;

    /// Asserts that two pixel values agree to within [`EPS`].
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn preprocess_greyscale_bytes() {
        let out = prepare_image(ImageSource::from_bytes(&[0, 128, 255, 64], (2, 2)).unwrap());
        assert_eq!(out.shape(), [1, 2, 2]);
        assert_close(out[[0, 0, 0]], BLACK_VALUE);
        assert_close(out[[0, 0, 1]], BLACK_VALUE + 128.0 / 255.0);
        assert_close(out[[0, 1, 0]], BLACK_VALUE + 1.0);
        assert_close(out[[0, 1, 1]], BLACK_VALUE + 64.0 / 255.0);
    }

    #[test]
    fn preprocess_rgb_and_rgba_bytes() {
        // A pure red pixel contributes only the red weight.
        let rgb = prepare_image(ImageSource::from_bytes(&[255, 0, 0], (1, 1)).unwrap());
        assert_eq!(rgb.shape(), [1, 1, 1]);
        assert_close(rgb[[0, 0, 0]], BLACK_VALUE + 0.299);

        // The fourth (alpha) byte must not influence the result.
        let rgba = prepare_image(ImageSource::from_bytes(&[0, 255, 0, 9], (1, 1)).unwrap());
        assert_eq!(rgba.shape(), [1, 1, 1]);
        assert_close(rgba[[0, 0, 0]], BLACK_VALUE + 0.587);
    }

    #[test]
    fn preprocess_chw_floats() {
        // Shape [C=3, H=1, W=2]: pixel 0 is pure red, pixel 1 is pure green.
        let chw = NdTensor::from_data([3, 1, 2], vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let out = prepare_image(ImageSource::from_tensor(chw.view(), DimOrder::Chw).unwrap());
        assert_eq!(out.shape(), [1, 1, 2]);
        assert_close(out[[0, 0, 0]], BLACK_VALUE + 0.299);
        assert_close(out[[0, 0, 1]], BLACK_VALUE + 0.587);
    }

    #[test]
    fn from_bytes_rejects_invalid_input() {
        // 5 bytes cannot be split evenly over 4 pixels.
        assert_eq!(
            ImageSource::from_bytes(&[0; 5], (2, 2)).err(),
            Some(ImageSourceError::InvalidDataLength)
        );
        // 8 bytes over 4 pixels means 2 channels, which is unsupported.
        assert_eq!(
            ImageSource::from_bytes(&[0; 8], (2, 2)).err(),
            Some(ImageSourceError::UnsupportedChannelCount)
        );
        // A zero-area image is reported as an unsupported channel count.
        assert_eq!(
            ImageSource::from_bytes(&[], (0, 2)).err(),
            Some(ImageSourceError::UnsupportedChannelCount)
        );
    }

    #[test]
    fn from_tensor_rejects_unsupported_channel_count() {
        let hwc = NdTensor::from_data([1, 1, 2], vec![0.0, 0.0]);
        assert_eq!(
            ImageSource::from_tensor(hwc.view(), DimOrder::Hwc).err(),
            Some(ImageSourceError::UnsupportedChannelCount)
        );
    }
}