Skip to main content

rlx_ocr/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level OCR runner with image loading.
17
18use crate::capabilities::validate_device;
19use crate::config::OcrConfig;
20use crate::engine::{OcrEngine, OcrEngineParams};
21use crate::text::TextLine;
22use crate::weights::resolve_model_dir;
23use anyhow::{Context, Result, anyhow};
24use rlx_runtime::Device;
25use rten_imageproc::RotatedRect;
26use std::path::{Path, PathBuf};
27
28/// Structured OCR output.
29#[derive(Debug, Clone)]
30pub struct OcrOutput {
31    pub text: String,
32    pub lines: Vec<Option<TextLine>>,
33    pub words: Vec<RotatedRect>,
34}
35
36/// Builder for [`OcrRunner`] (mirrors whisper / dinov2 runners).
37#[derive(Debug, Clone, Default)]
38pub struct OcrRunnerBuilder {
39    model_dir: Option<PathBuf>,
40    detection_model: Option<PathBuf>,
41    recognition_model: Option<PathBuf>,
42    device: Option<Device>,
43    alphabet: Option<String>,
44}
45
46impl OcrRunnerBuilder {
47    pub fn model_dir<P: Into<PathBuf>>(mut self, dir: P) -> Self {
48        self.model_dir = Some(dir.into());
49        self
50    }
51
52    pub fn detection_model<P: Into<PathBuf>>(mut self, p: P) -> Self {
53        self.detection_model = Some(p.into());
54        self
55    }
56
57    pub fn recognition_model<P: Into<PathBuf>>(mut self, p: P) -> Self {
58        self.recognition_model = Some(p.into());
59        self
60    }
61
62    pub fn device(mut self, d: Device) -> Self {
63        self.device = Some(d);
64        self
65    }
66
67    pub fn alphabet(mut self, alphabet: impl Into<String>) -> Self {
68        self.alphabet = Some(alphabet.into());
69        self
70    }
71
72    pub fn build(self) -> Result<OcrRunner> {
73        let (detection, recognition) = match (
74            self.detection_model,
75            self.recognition_model,
76            self.model_dir,
77        ) {
78            (Some(d), Some(r), _) => (d, r),
79            (_, _, Some(dir)) => resolve_model_dir(&dir)?,
80            _ => {
81                return Err(anyhow!(
82                    "provide model_dir(...) or both detection_model(...) and recognition_model(...)"
83                ));
84            }
85        };
86        let device = self.device.unwrap_or(Device::Cpu);
87        validate_device(device)?;
88
89        let engine = OcrEngine::new(OcrEngineParams {
90            detection_model: Some(detection),
91            recognition_model: Some(recognition),
92            alphabet: self.alphabet,
93            device,
94            ..Default::default()
95        })?;
96
97        Ok(OcrRunner { engine, device })
98    }
99}
100
101/// OCR session wrapping a fully loaded [`OcrEngine`].
102pub struct OcrRunner {
103    engine: OcrEngine,
104    device: Device,
105}
106
107impl OcrRunner {
108    pub fn builder() -> OcrRunnerBuilder {
109        OcrRunnerBuilder::default()
110    }
111
112    pub fn engine(&self) -> &OcrEngine {
113        &self.engine
114    }
115
116    pub fn device(&self) -> Device {
117        self.device
118    }
119
120    pub fn config(&self) -> OcrConfig {
121        self.engine.config()
122    }
123
124    /// Run OCR on an image file (JPEG/PNG via `image` crate).
125    pub fn predict_path(&self, path: &Path) -> Result<OcrOutput> {
126        let img = image::open(path)
127            .with_context(|| format!("open image {path:?}"))?
128            .into_rgb8();
129        let (w, h) = img.dimensions();
130        self.predict_rgb(img.as_raw(), w, h)
131    }
132
133    /// Run OCR on RGB8 bytes.
134    pub fn predict_rgb(&self, rgb: &[u8], width: u32, height: u32) -> Result<OcrOutput> {
135        let source = crate::ImageSource::from_bytes(rgb, (width, height))?;
136        let input = self.engine.prepare_input(source)?;
137        let words = self.engine.detect_words(&input)?;
138        let line_rects = self.engine.find_text_lines(&input, &words);
139        let lines = self.engine.recognize_text(&input, &line_rects)?;
140        let text = lines
141            .iter()
142            .filter_map(|l| l.as_ref().map(TextLine::text))
143            .collect::<Vec<_>>()
144            .join("\n");
145        Ok(OcrOutput { text, lines, words })
146    }
147
148    /// Convenience: text only.
149    pub fn predict_text(&self, path: &Path) -> Result<String> {
150        Ok(self.predict_path(path)?.text)
151    }
152}