1use crate::capabilities::validate_device;
19use crate::config::OcrConfig;
20use crate::engine::{OcrEngine, OcrEngineParams};
21use crate::text::TextLine;
22use crate::weights::resolve_model_dir;
23use anyhow::{Context, Result, anyhow};
24use rlx_runtime::Device;
25use rten_imageproc::RotatedRect;
26use std::path::{Path, PathBuf};
27
28#[derive(Debug, Clone)]
30pub struct OcrOutput {
31 pub text: String,
32 pub lines: Vec<Option<TextLine>>,
33 pub words: Vec<RotatedRect>,
34}
35
36#[derive(Debug, Clone, Default)]
38pub struct OcrRunnerBuilder {
39 model_dir: Option<PathBuf>,
40 detection_model: Option<PathBuf>,
41 recognition_model: Option<PathBuf>,
42 device: Option<Device>,
43 alphabet: Option<String>,
44}
45
46impl OcrRunnerBuilder {
47 pub fn model_dir<P: Into<PathBuf>>(mut self, dir: P) -> Self {
48 self.model_dir = Some(dir.into());
49 self
50 }
51
52 pub fn detection_model<P: Into<PathBuf>>(mut self, p: P) -> Self {
53 self.detection_model = Some(p.into());
54 self
55 }
56
57 pub fn recognition_model<P: Into<PathBuf>>(mut self, p: P) -> Self {
58 self.recognition_model = Some(p.into());
59 self
60 }
61
62 pub fn device(mut self, d: Device) -> Self {
63 self.device = Some(d);
64 self
65 }
66
67 pub fn alphabet(mut self, alphabet: impl Into<String>) -> Self {
68 self.alphabet = Some(alphabet.into());
69 self
70 }
71
72 pub fn build(self) -> Result<OcrRunner> {
73 let (detection, recognition) = match (
74 self.detection_model,
75 self.recognition_model,
76 self.model_dir,
77 ) {
78 (Some(d), Some(r), _) => (d, r),
79 (_, _, Some(dir)) => resolve_model_dir(&dir)?,
80 _ => {
81 return Err(anyhow!(
82 "provide model_dir(...) or both detection_model(...) and recognition_model(...)"
83 ));
84 }
85 };
86 let device = self.device.unwrap_or(Device::Cpu);
87 validate_device(device)?;
88
89 let engine = OcrEngine::new(OcrEngineParams {
90 detection_model: Some(detection),
91 recognition_model: Some(recognition),
92 alphabet: self.alphabet,
93 device,
94 ..Default::default()
95 })?;
96
97 Ok(OcrRunner { engine, device })
98 }
99}
100
101pub struct OcrRunner {
103 engine: OcrEngine,
104 device: Device,
105}
106
107impl OcrRunner {
108 pub fn builder() -> OcrRunnerBuilder {
109 OcrRunnerBuilder::default()
110 }
111
112 pub fn engine(&self) -> &OcrEngine {
113 &self.engine
114 }
115
116 pub fn device(&self) -> Device {
117 self.device
118 }
119
120 pub fn config(&self) -> OcrConfig {
121 self.engine.config()
122 }
123
124 pub fn predict_path(&self, path: &Path) -> Result<OcrOutput> {
126 let img = image::open(path)
127 .with_context(|| format!("open image {path:?}"))?
128 .into_rgb8();
129 let (w, h) = img.dimensions();
130 self.predict_rgb(img.as_raw(), w, h)
131 }
132
133 pub fn predict_rgb(&self, rgb: &[u8], width: u32, height: u32) -> Result<OcrOutput> {
135 let source = crate::ImageSource::from_bytes(rgb, (width, height))?;
136 let input = self.engine.prepare_input(source)?;
137 let words = self.engine.detect_words(&input)?;
138 let line_rects = self.engine.find_text_lines(&input, &words);
139 let lines = self.engine.recognize_text(&input, &line_rects)?;
140 let text = lines
141 .iter()
142 .filter_map(|l| l.as_ref().map(TextLine::text))
143 .collect::<Vec<_>>()
144 .join("\n");
145 Ok(OcrOutput { text, lines, words })
146 }
147
148 pub fn predict_text(&self, path: &Path) -> Result<String> {
150 Ok(self.predict_path(path)?.text)
151 }
152}