Skip to main content

rlx_ocr/
engine.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! High-level OCR engine: text detection, line layout, and text recognition.
17
18use crate::config::{DEFAULT_ALPHABET, DecodeMethod, DetectionParams, OcrConfig};
19use crate::layout::find_text_lines;
20use crate::preprocess::{ImageSource, prepare_image};
21use crate::text::TextLine;
22use anyhow::{Context, Result, anyhow};
23use rlx_runtime::Device;
24use rten_imageproc::RotatedRect;
25use rten_tensor::prelude::*;
26use rten_tensor::{NdTensor, NdTensorView};
27
28#[cfg(feature = "rlx")]
29use crate::rlx::{RlxTextDetector, RlxTextRecognizer};
30
31#[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
32use crate::inference::{RtenTextDetector, RtenTextRecognizer};
33
34/// Parameters for constructing an [`OcrEngine`].
35pub struct OcrEngineParams {
36    pub detection_model: Option<std::path::PathBuf>,
37    pub recognition_model: Option<std::path::PathBuf>,
38    pub detection: DetectionParams,
39    pub decode_method: DecodeMethod,
40    pub alphabet: Option<String>,
41    pub allowed_chars: Option<String>,
42    pub device: Device,
43}
44
45impl Default for OcrEngineParams {
46    fn default() -> Self {
47        Self {
48            detection_model: None,
49            recognition_model: None,
50            detection: DetectionParams::default(),
51            decode_method: DecodeMethod::default(),
52            alphabet: None,
53            allowed_chars: None,
54            device: Device::Cpu,
55        }
56    }
57}
58
59/// Preprocessed greyscale input image `[1, H, W]`.
60pub struct OcrInput {
61    pub(crate) image: NdTensor<f32, 3>,
62}
63
64/// End-to-end OCR pipeline (ocrs-compatible API).
65pub struct OcrEngine {
66    #[cfg(feature = "rlx")]
67    detector: Option<RlxTextDetector>,
68    #[cfg(feature = "rlx")]
69    recognizer: Option<RlxTextRecognizer>,
70    #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
71    detector: Option<RtenTextDetector>,
72    #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
73    recognizer: Option<RtenTextRecognizer>,
74    detection: DetectionParams,
75    decode_method: DecodeMethod,
76    alphabet: String,
77    excluded_char_labels: Option<Vec<usize>>,
78}
79
80impl OcrEngine {
81    /// Build from explicit model paths (`.safetensors` for native RLX; `.rten` only with `rten-inference`).
82    pub fn from_paths(
83        detection: impl AsRef<std::path::Path>,
84        recognition: impl AsRef<std::path::Path>,
85    ) -> Result<Self> {
86        Self::from_paths_on_device(detection, recognition, Device::Cpu)
87    }
88
89    pub fn from_paths_on_device(
90        detection: impl AsRef<std::path::Path>,
91        recognition: impl AsRef<std::path::Path>,
92        device: Device,
93    ) -> Result<Self> {
94        Self::new(OcrEngineParams {
95            detection_model: Some(detection.as_ref().to_path_buf()),
96            recognition_model: Some(recognition.as_ref().to_path_buf()),
97            device,
98            ..Default::default()
99        })
100    }
101
102    pub fn new(params: OcrEngineParams) -> Result<Self> {
103        let detection = params.detection;
104        let device = params.device;
105        #[cfg(feature = "rlx")]
106        let detector = params
107            .detection_model
108            .as_ref()
109            .map(|p| RlxTextDetector::from_path(p, detection.clone(), device))
110            .transpose()
111            .context("load detection model")?;
112        #[cfg(feature = "rlx")]
113        let recognizer = params
114            .recognition_model
115            .as_ref()
116            .map(|p| RlxTextRecognizer::from_path(p, device))
117            .transpose()
118            .context("load recognition model")?;
119
120        #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
121        let detector = params
122            .detection_model
123            .as_ref()
124            .map(|p| RtenTextDetector::from_path(p, detection.clone()))
125            .transpose()
126            .context("load detection model")?;
127        #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
128        let recognizer = params
129            .recognition_model
130            .as_ref()
131            .map(RtenTextRecognizer::from_path)
132            .transpose()
133            .context("load recognition model")?;
134
135        let alphabet = params
136            .alphabet
137            .unwrap_or_else(|| DEFAULT_ALPHABET.to_string());
138        let excluded_char_labels = params.allowed_chars.as_ref().map(|allowed| {
139            alphabet
140                .chars()
141                .enumerate()
142                .filter_map(|(i, ch)| {
143                    if allowed.contains(ch) {
144                        None
145                    } else {
146                        Some(i + 1)
147                    }
148                })
149                .collect()
150        });
151
152        Ok(Self {
153            detector,
154            recognizer,
155            detection,
156            decode_method: params.decode_method,
157            alphabet,
158            excluded_char_labels,
159        })
160    }
161
162    /// Load default HuggingFace checkpoint filenames from a model directory (CPU).
163    pub fn from_model_dir(dir: impl AsRef<std::path::Path>) -> Result<Self> {
164        Self::from_model_dir_on_device(dir, Device::Cpu)
165    }
166
167    /// Load checkpoints from `dir` and compile graphs on `device`.
168    pub fn from_model_dir_on_device(
169        dir: impl AsRef<std::path::Path>,
170        device: Device,
171    ) -> Result<Self> {
172        let (det, rec) = crate::weights::resolve_model_dir(dir.as_ref())?;
173        Self::from_paths_on_device(det, rec, device)
174    }
175
176    pub fn prepare_input(&self, source: ImageSource<'_>) -> Result<OcrInput> {
177        Ok(OcrInput {
178            image: prepare_image(source),
179        })
180    }
181
182    pub fn detection_threshold(&self) -> f32 {
183        self.detection.text_threshold
184    }
185
186    pub fn detect_words(&self, input: &OcrInput) -> Result<Vec<RotatedRect>> {
187        let detector = self
188            .detector
189            .as_ref()
190            .ok_or_else(|| anyhow!("detection model not configured"))?;
191        detector.detect_words(input.image.view())
192    }
193
194    pub fn detect_text_pixels(&self, input: &OcrInput) -> Result<NdTensor<f32, 2>> {
195        let detector = self
196            .detector
197            .as_ref()
198            .ok_or_else(|| anyhow!("detection model not configured"))?;
199        detector.detect_text_pixels(input.image.view())
200    }
201
202    pub fn find_text_lines(
203        &self,
204        _input: &OcrInput,
205        words: &[RotatedRect],
206    ) -> Vec<Vec<RotatedRect>> {
207        find_text_lines(words)
208    }
209
210    pub fn prepare_recognition_input(
211        &self,
212        input: &OcrInput,
213        line: &[RotatedRect],
214    ) -> Result<NdTensor<f32, 2>> {
215        let recognizer = self
216            .recognizer
217            .as_ref()
218            .ok_or_else(|| anyhow!("recognition model not configured"))?;
219        Ok(recognizer.prepare_input(input.image.view(), line))
220    }
221
222    pub fn recognize_text(
223        &self,
224        input: &OcrInput,
225        lines: &[Vec<RotatedRect>],
226    ) -> Result<Vec<Option<TextLine>>> {
227        let recognizer = self
228            .recognizer
229            .as_ref()
230            .ok_or_else(|| anyhow!("recognition model not configured"))?;
231        recognizer.recognize_text_lines(
232            input.image.view(),
233            lines,
234            self.decode_method,
235            &self.alphabet,
236            self.excluded_char_labels.as_deref(),
237        )
238    }
239
240    pub fn get_text(&self, input: &OcrInput) -> Result<String> {
241        let words = self.detect_words(input)?;
242        let lines = self.find_text_lines(input, &words);
243        let recognized = self.recognize_text(input, &lines)?;
244        Ok(recognized
245            .into_iter()
246            .filter_map(|l| l.map(|tl| tl.text()))
247            .collect::<Vec<_>>()
248            .join("\n"))
249    }
250
251    pub fn config(&self) -> OcrConfig {
252        OcrConfig {
253            detection: self.detection.clone(),
254            decode_method: self.decode_method,
255            alphabet: self.alphabet.clone(),
256        }
257    }
258}
259
260pub fn input_image(input: &OcrInput) -> NdTensorView<'_, f32, 3> {
261    input.image.view()
262}
263
264pub fn ocr_rgb_bytes(engine: &OcrEngine, rgb: &[u8], width: u32, height: u32) -> Result<String> {
265    let source = ImageSource::from_bytes(rgb, (width, height))?;
266    let input = engine.prepare_input(source)?;
267    engine.get_text(&input)
268}