1use crate::config::{DEFAULT_ALPHABET, DecodeMethod, DetectionParams, OcrConfig};
19use crate::layout::find_text_lines;
20use crate::preprocess::{ImageSource, prepare_image};
21use crate::text::TextLine;
22use anyhow::{Context, Result, anyhow};
23use rlx_runtime::Device;
24use rten_imageproc::RotatedRect;
25use rten_tensor::prelude::*;
26use rten_tensor::{NdTensor, NdTensorView};
27
28#[cfg(feature = "rlx")]
29use crate::rlx::{RlxTextDetector, RlxTextRecognizer};
30
31#[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
32use crate::inference::{RtenTextDetector, RtenTextRecognizer};
33
34pub struct OcrEngineParams {
36 pub detection_model: Option<std::path::PathBuf>,
37 pub recognition_model: Option<std::path::PathBuf>,
38 pub detection: DetectionParams,
39 pub decode_method: DecodeMethod,
40 pub alphabet: Option<String>,
41 pub allowed_chars: Option<String>,
42 pub device: Device,
43}
44
45impl Default for OcrEngineParams {
46 fn default() -> Self {
47 Self {
48 detection_model: None,
49 recognition_model: None,
50 detection: DetectionParams::default(),
51 decode_method: DecodeMethod::default(),
52 alphabet: None,
53 allowed_chars: None,
54 device: Device::Cpu,
55 }
56 }
57}
58
59pub struct OcrInput {
61 pub(crate) image: NdTensor<f32, 3>,
62}
63
64pub struct OcrEngine {
66 #[cfg(feature = "rlx")]
67 detector: Option<RlxTextDetector>,
68 #[cfg(feature = "rlx")]
69 recognizer: Option<RlxTextRecognizer>,
70 #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
71 detector: Option<RtenTextDetector>,
72 #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
73 recognizer: Option<RtenTextRecognizer>,
74 detection: DetectionParams,
75 decode_method: DecodeMethod,
76 alphabet: String,
77 excluded_char_labels: Option<Vec<usize>>,
78}
79
80impl OcrEngine {
81 pub fn from_paths(
83 detection: impl AsRef<std::path::Path>,
84 recognition: impl AsRef<std::path::Path>,
85 ) -> Result<Self> {
86 Self::from_paths_on_device(detection, recognition, Device::Cpu)
87 }
88
89 pub fn from_paths_on_device(
90 detection: impl AsRef<std::path::Path>,
91 recognition: impl AsRef<std::path::Path>,
92 device: Device,
93 ) -> Result<Self> {
94 Self::new(OcrEngineParams {
95 detection_model: Some(detection.as_ref().to_path_buf()),
96 recognition_model: Some(recognition.as_ref().to_path_buf()),
97 device,
98 ..Default::default()
99 })
100 }
101
102 pub fn new(params: OcrEngineParams) -> Result<Self> {
103 let detection = params.detection;
104 let device = params.device;
105 #[cfg(feature = "rlx")]
106 let detector = params
107 .detection_model
108 .as_ref()
109 .map(|p| RlxTextDetector::from_path(p, detection.clone(), device))
110 .transpose()
111 .context("load detection model")?;
112 #[cfg(feature = "rlx")]
113 let recognizer = params
114 .recognition_model
115 .as_ref()
116 .map(|p| RlxTextRecognizer::from_path(p, device))
117 .transpose()
118 .context("load recognition model")?;
119
120 #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
121 let detector = params
122 .detection_model
123 .as_ref()
124 .map(|p| RtenTextDetector::from_path(p, detection.clone()))
125 .transpose()
126 .context("load detection model")?;
127 #[cfg(all(feature = "rten-inference", not(feature = "rlx")))]
128 let recognizer = params
129 .recognition_model
130 .as_ref()
131 .map(RtenTextRecognizer::from_path)
132 .transpose()
133 .context("load recognition model")?;
134
135 let alphabet = params
136 .alphabet
137 .unwrap_or_else(|| DEFAULT_ALPHABET.to_string());
138 let excluded_char_labels = params.allowed_chars.as_ref().map(|allowed| {
139 alphabet
140 .chars()
141 .enumerate()
142 .filter_map(|(i, ch)| {
143 if allowed.contains(ch) {
144 None
145 } else {
146 Some(i + 1)
147 }
148 })
149 .collect()
150 });
151
152 Ok(Self {
153 detector,
154 recognizer,
155 detection,
156 decode_method: params.decode_method,
157 alphabet,
158 excluded_char_labels,
159 })
160 }
161
162 pub fn from_model_dir(dir: impl AsRef<std::path::Path>) -> Result<Self> {
164 Self::from_model_dir_on_device(dir, Device::Cpu)
165 }
166
167 pub fn from_model_dir_on_device(
169 dir: impl AsRef<std::path::Path>,
170 device: Device,
171 ) -> Result<Self> {
172 let (det, rec) = crate::weights::resolve_model_dir(dir.as_ref())?;
173 Self::from_paths_on_device(det, rec, device)
174 }
175
176 pub fn prepare_input(&self, source: ImageSource<'_>) -> Result<OcrInput> {
177 Ok(OcrInput {
178 image: prepare_image(source),
179 })
180 }
181
182 pub fn detection_threshold(&self) -> f32 {
183 self.detection.text_threshold
184 }
185
186 pub fn detect_words(&self, input: &OcrInput) -> Result<Vec<RotatedRect>> {
187 let detector = self
188 .detector
189 .as_ref()
190 .ok_or_else(|| anyhow!("detection model not configured"))?;
191 detector.detect_words(input.image.view())
192 }
193
194 pub fn detect_text_pixels(&self, input: &OcrInput) -> Result<NdTensor<f32, 2>> {
195 let detector = self
196 .detector
197 .as_ref()
198 .ok_or_else(|| anyhow!("detection model not configured"))?;
199 detector.detect_text_pixels(input.image.view())
200 }
201
202 pub fn find_text_lines(
203 &self,
204 _input: &OcrInput,
205 words: &[RotatedRect],
206 ) -> Vec<Vec<RotatedRect>> {
207 find_text_lines(words)
208 }
209
210 pub fn prepare_recognition_input(
211 &self,
212 input: &OcrInput,
213 line: &[RotatedRect],
214 ) -> Result<NdTensor<f32, 2>> {
215 let recognizer = self
216 .recognizer
217 .as_ref()
218 .ok_or_else(|| anyhow!("recognition model not configured"))?;
219 Ok(recognizer.prepare_input(input.image.view(), line))
220 }
221
222 pub fn recognize_text(
223 &self,
224 input: &OcrInput,
225 lines: &[Vec<RotatedRect>],
226 ) -> Result<Vec<Option<TextLine>>> {
227 let recognizer = self
228 .recognizer
229 .as_ref()
230 .ok_or_else(|| anyhow!("recognition model not configured"))?;
231 recognizer.recognize_text_lines(
232 input.image.view(),
233 lines,
234 self.decode_method,
235 &self.alphabet,
236 self.excluded_char_labels.as_deref(),
237 )
238 }
239
240 pub fn get_text(&self, input: &OcrInput) -> Result<String> {
241 let words = self.detect_words(input)?;
242 let lines = self.find_text_lines(input, &words);
243 let recognized = self.recognize_text(input, &lines)?;
244 Ok(recognized
245 .into_iter()
246 .filter_map(|l| l.map(|tl| tl.text()))
247 .collect::<Vec<_>>()
248 .join("\n"))
249 }
250
251 pub fn config(&self) -> OcrConfig {
252 OcrConfig {
253 detection: self.detection.clone(),
254 decode_method: self.decode_method,
255 alphabet: self.alphabet.clone(),
256 }
257 }
258}
259
260pub fn input_image(input: &OcrInput) -> NdTensorView<'_, f32, 3> {
261 input.image.view()
262}
263
264pub fn ocr_rgb_bytes(engine: &OcrEngine, rgb: &[u8], width: u32, height: u32) -> Result<String> {
265 let source = ImageSource::from_bytes(rgb, (width, height))?;
266 let input = engine.prepare_input(source)?;
267 engine.get_text(&input)
268}