1use crate::capabilities::validate_device;
17use crate::config::{DecodeMethod, RECOGNITION_INPUT_HEIGHT};
18use crate::model::{
19 NUM_CLASSES, RecognitionGraphConfig, build_recognition_graph, log_softmax_last_axis,
20};
21use crate::recognition::line_batch::{
22 LineRecResult, TextRecLine, bounding_rect, filter_excluded_char_labels, line_polygon,
23 prepare_text_line, prepare_text_line_batch, resized_line_width,
24 text_lines_from_recognition_results,
25};
26use crate::text::TextLine;
27use crate::weights::{
28 HF_RECOGNITION_ST, HF_RECOGNITION_ST_FULL, SafetensorsFile, prefer_safetensors_path,
29};
30use anyhow::{Result, anyhow};
31use rlx_core::flow_bridge::compile_options_for_profile;
32use rlx_core::flow_util::attach_built_params;
33use rlx_flow::CompileProfile;
34use rlx_runtime::{CompiledGraph, Device, Session};
35use rten_imageproc::{Polygon, Rect, RotatedRect};
36use rten_tensor::prelude::*;
37use rten_tensor::{NdTensor, NdTensorView};
38use std::collections::HashMap;
39use std::path::Path;
40use std::sync::Mutex;
41
42pub struct RlxTextRecognizer {
44 weights: SafetensorsFile,
45 device: Device,
46 cache: Mutex<HashMap<usize, CompiledGraph>>,
47 input_height: u32,
48}
49
50impl RlxTextRecognizer {
51 pub fn from_path(path: impl AsRef<Path>, device: Device) -> Result<Self> {
52 Self::from_safetensors(path.as_ref(), device)
53 }
54
55 pub fn from_safetensors(path: &Path, device: Device) -> Result<Self> {
56 validate_device(device)?;
57 if !path.is_file() {
58 anyhow::bail!("recognition weights not found: {path:?}");
59 }
60 Ok(Self {
61 weights: SafetensorsFile::open(path)?,
62 device,
63 cache: Mutex::new(HashMap::new()),
64 input_height: RECOGNITION_INPUT_HEIGHT,
65 })
66 }
67
68 pub fn from_model_dir(dir: &Path, device: Device) -> Result<Self> {
69 validate_device(device)?;
70 let path = prefer_safetensors_path(dir, HF_RECOGNITION_ST, HF_RECOGNITION_ST_FULL);
71 Self::from_safetensors(&path, device)
72 }
73
74 fn ensure_compiled(&self, compile_width: usize, max_lines_per_group: usize) -> Result<()> {
75 let mut cache = self.cache.lock().map_err(|_| anyhow!("lock poisoned"))?;
76 if cache.contains_key(&compile_width) {
77 return Ok(());
78 }
79 let mut wm = self.weights.weight_map()?;
80 let (graph, params) = build_recognition_graph(
81 &mut wm,
82 RecognitionGraphConfig {
83 batch: max_lines_per_group,
84 width: compile_width,
85 },
86 )?;
87 let opts = compile_options_for_profile(&CompileProfile::encoder(), self.device);
88 let mut compiled = Session::new(self.device).compile_with(graph, &opts);
89 attach_built_params(&mut compiled, params, &[]);
90 cache.insert(compile_width, compiled);
91 Ok(())
92 }
93
94 pub fn run_batch_logits(&self, input: NdTensor<f32, 4>) -> Result<NdTensor<f32, 3>> {
96 let w = input.size(3);
97 let batch = input.size(0);
98 self.ensure_compiled(w, batch)?;
99 let input_vec: Vec<f32> = input.iter().copied().collect();
100 let mut cache = self.cache.lock().map_err(|_| anyhow!("lock poisoned"))?;
101 let compiled = cache.get_mut(&w).unwrap();
102 let mut flat = compiled
103 .run(&[("image", input_vec.as_slice())])
104 .into_iter()
105 .next()
106 .ok_or_else(|| anyhow!("recognition returned no output"))?;
107 let seq_len = flat.len() / (batch * NUM_CLASSES);
108 log_softmax_last_axis(&mut flat, NUM_CLASSES);
109 Ok(NdTensor::from_data([batch, seq_len, NUM_CLASSES], flat))
110 }
111
112 pub fn prepare_input(
113 &self,
114 image: NdTensorView<f32, 3>,
115 line: &[RotatedRect],
116 ) -> NdTensor<f32, 2> {
117 let [_, img_height, img_width] = image.shape();
118 let page_rect = Rect::from_hw(img_height as i32, img_width as i32);
119 let line_rect = bounding_rect(line.iter())
120 .expect("line has no words")
121 .integral_bounding_rect();
122 let line_poly = Polygon::new(line_polygon(line));
123 let resized_width = resized_line_width(
124 line_rect.width(),
125 line_rect.height(),
126 self.input_height as i32,
127 );
128 prepare_text_line(
129 image,
130 page_rect,
131 &line_poly,
132 resized_width,
133 self.input_height as usize,
134 )
135 }
136
137 fn run_group(
138 &self,
139 image: &NdTensorView<f32, 3>,
140 lines: Vec<TextRecLine>,
141 page_rect: Rect,
142 compile_width: usize,
143 rec_img_height: u32,
144 decode_method: DecodeMethod,
145 alphabet_len: usize,
146 excluded_char_labels: Option<&[usize]>,
147 ) -> Result<Vec<LineRecResult>> {
148 self.ensure_compiled(compile_width, lines.len().max(1))?;
149 let rec_input = prepare_text_line_batch(
150 image,
151 &lines,
152 page_rect,
153 rec_img_height as usize,
154 compile_width,
155 );
156 let input: Vec<f32> = rec_input.iter().copied().collect();
157 let mut cache = self.cache.lock().map_err(|_| anyhow!("lock poisoned"))?;
158 let compiled = cache.get_mut(&compile_width).unwrap();
159 let outputs = compiled.run(&[("image", input.as_slice())]);
160 let flat_out = outputs
161 .into_iter()
162 .next()
163 .ok_or_else(|| anyhow!("recognition returned no output"))?;
164
165 let n_classes = alphabet_len + 1;
166 let seq_len = flat_out.len() / (lines.len() * n_classes);
167 let n = lines.len() * seq_len * n_classes;
168 let mut logits = flat_out[..n].to_vec();
169 log_softmax_last_axis(&mut logits, n_classes);
170 let mut rec_output = NdTensor::from_data([lines.len(), seq_len, n_classes], logits);
171
172 if n_classes != rec_output.size(2) {
173 return Err(anyhow!(
174 "recognition classes ({}) != alphabet+blank ({n_classes})",
175 rec_output.size(2)
176 ));
177 }
178 let ctc_input_len = rec_output.size(1);
179
180 Ok(lines
181 .into_iter()
182 .enumerate()
183 .map(|(group_line_index, line)| {
184 let mut input_seq_slice = rec_output.slice_mut([group_line_index]);
185 let input_seq =
186 filter_excluded_char_labels(excluded_char_labels, &mut input_seq_slice);
187 let ctc_output = crate::ctc::decode(input_seq, decode_method);
188 LineRecResult {
189 line,
190 rec_input_len: compile_width,
191 ctc_input_len,
192 ctc_output,
193 }
194 })
195 .collect())
196 }
197
198 pub fn recognize_text_lines(
199 &self,
200 image: NdTensorView<f32, 3>,
201 lines: &[Vec<RotatedRect>],
202 decode_method: DecodeMethod,
203 alphabet: &str,
204 excluded_char_labels: Option<&[usize]>,
205 ) -> Result<Vec<Option<TextLine>>> {
206 let [_, img_height, img_width] = image.shape();
207 let page_rect = Rect::from_hw(img_height as i32, img_width as i32);
208 let rec_img_height = self.input_height;
209
210 let mut line_groups: HashMap<i32, Vec<TextRecLine>> = HashMap::new();
211 for (line_index, word_rects) in lines.iter().enumerate() {
212 let line_rect = bounding_rect(word_rects.iter())
213 .expect("line has no words")
214 .integral_bounding_rect();
215 let resized_width =
216 resized_line_width(line_rect.width(), line_rect.height(), rec_img_height as i32);
217 let group_width = resized_width.next_multiple_of(50);
218 line_groups
219 .entry(group_width as i32)
220 .or_default()
221 .push(TextRecLine {
222 index: line_index,
223 region: Polygon::new(line_polygon(word_rects)),
224 resized_width,
225 });
226 }
227
228 let max_lines_per_group = 20;
229 let line_groups: Vec<(usize, Vec<TextRecLine>)> = line_groups
230 .into_iter()
231 .flat_map(|(group_width, lines)| {
232 lines
233 .chunks(max_lines_per_group)
234 .map(|chunk| (group_width as usize, chunk.to_vec()))
235 .collect::<Vec<_>>()
236 })
237 .collect();
238
239 let alphabet_len = alphabet.chars().count();
240 for &(compile_width, ref chunk) in &line_groups {
241 self.ensure_compiled(compile_width, chunk.len().max(1))?;
242 }
243
244 let group_results: Result<Vec<Vec<LineRecResult>>> = line_groups
245 .into_iter()
246 .map(|(compile_width, lines)| {
247 self.run_group(
248 &image,
249 lines,
250 page_rect,
251 compile_width,
252 rec_img_height,
253 decode_method,
254 alphabet_len,
255 excluded_char_labels,
256 )
257 })
258 .collect();
259
260 let mut line_rec_results: Vec<LineRecResult> =
261 group_results?.into_iter().flatten().collect();
262 line_rec_results.sort_by_key(|r| r.line.index);
263 Ok(text_lines_from_recognition_results(
264 &line_rec_results,
265 alphabet,
266 ))
267 }
268}