Skip to main content

rlx_ocr/rlx/
recognition.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::capabilities::validate_device;
17use crate::config::{DecodeMethod, RECOGNITION_INPUT_HEIGHT};
18use crate::model::{
19    NUM_CLASSES, RecognitionGraphConfig, build_recognition_graph, log_softmax_last_axis,
20};
21use crate::recognition::line_batch::{
22    LineRecResult, TextRecLine, bounding_rect, filter_excluded_char_labels, line_polygon,
23    prepare_text_line, prepare_text_line_batch, resized_line_width,
24    text_lines_from_recognition_results,
25};
26use crate::text::TextLine;
27use crate::weights::{
28    HF_RECOGNITION_ST, HF_RECOGNITION_ST_FULL, SafetensorsFile, prefer_safetensors_path,
29};
30use anyhow::{Result, anyhow};
31use rlx_core::flow_bridge::compile_options_for_profile;
32use rlx_core::flow_util::attach_built_params;
33use rlx_flow::CompileProfile;
34use rlx_runtime::{CompiledGraph, Device, Session};
35use rten_imageproc::{Polygon, Rect, RotatedRect};
36use rten_tensor::prelude::*;
37use rten_tensor::{NdTensor, NdTensorView};
38use std::collections::HashMap;
39use std::path::Path;
40use std::sync::Mutex;
41
42/// Recognition backend using compiled native RLX CRNN + GRU graphs (per padded width).
43pub struct RlxTextRecognizer {
44    weights: SafetensorsFile,
45    device: Device,
46    cache: Mutex<HashMap<usize, CompiledGraph>>,
47    input_height: u32,
48}
49
50impl RlxTextRecognizer {
51    pub fn from_path(path: impl AsRef<Path>, device: Device) -> Result<Self> {
52        Self::from_safetensors(path.as_ref(), device)
53    }
54
55    pub fn from_safetensors(path: &Path, device: Device) -> Result<Self> {
56        validate_device(device)?;
57        if !path.is_file() {
58            anyhow::bail!("recognition weights not found: {path:?}");
59        }
60        Ok(Self {
61            weights: SafetensorsFile::open(path)?,
62            device,
63            cache: Mutex::new(HashMap::new()),
64            input_height: RECOGNITION_INPUT_HEIGHT,
65        })
66    }
67
68    pub fn from_model_dir(dir: &Path, device: Device) -> Result<Self> {
69        validate_device(device)?;
70        let path = prefer_safetensors_path(dir, HF_RECOGNITION_ST, HF_RECOGNITION_ST_FULL);
71        Self::from_safetensors(&path, device)
72    }
73
74    fn ensure_compiled(&self, compile_width: usize, max_lines_per_group: usize) -> Result<()> {
75        let mut cache = self.cache.lock().map_err(|_| anyhow!("lock poisoned"))?;
76        if cache.contains_key(&compile_width) {
77            return Ok(());
78        }
79        let mut wm = self.weights.weight_map()?;
80        let (graph, params) = build_recognition_graph(
81            &mut wm,
82            RecognitionGraphConfig {
83                batch: max_lines_per_group,
84                width: compile_width,
85            },
86        )?;
87        let opts = compile_options_for_profile(&CompileProfile::encoder(), self.device);
88        let mut compiled = Session::new(self.device).compile_with(graph, &opts);
89        attach_built_params(&mut compiled, params, &[]);
90        cache.insert(compile_width, compiled);
91        Ok(())
92    }
93
94    /// Run recognition on an NCHW batch; returns `[batch, seq, classes]` log-probs.
95    pub fn run_batch_logits(&self, input: NdTensor<f32, 4>) -> Result<NdTensor<f32, 3>> {
96        let w = input.size(3);
97        let batch = input.size(0);
98        self.ensure_compiled(w, batch)?;
99        let input_vec: Vec<f32> = input.iter().copied().collect();
100        let mut cache = self.cache.lock().map_err(|_| anyhow!("lock poisoned"))?;
101        let compiled = cache.get_mut(&w).unwrap();
102        let mut flat = compiled
103            .run(&[("image", input_vec.as_slice())])
104            .into_iter()
105            .next()
106            .ok_or_else(|| anyhow!("recognition returned no output"))?;
107        let seq_len = flat.len() / (batch * NUM_CLASSES);
108        log_softmax_last_axis(&mut flat, NUM_CLASSES);
109        Ok(NdTensor::from_data([batch, seq_len, NUM_CLASSES], flat))
110    }
111
112    pub fn prepare_input(
113        &self,
114        image: NdTensorView<f32, 3>,
115        line: &[RotatedRect],
116    ) -> NdTensor<f32, 2> {
117        let [_, img_height, img_width] = image.shape();
118        let page_rect = Rect::from_hw(img_height as i32, img_width as i32);
119        let line_rect = bounding_rect(line.iter())
120            .expect("line has no words")
121            .integral_bounding_rect();
122        let line_poly = Polygon::new(line_polygon(line));
123        let resized_width = resized_line_width(
124            line_rect.width(),
125            line_rect.height(),
126            self.input_height as i32,
127        );
128        prepare_text_line(
129            image,
130            page_rect,
131            &line_poly,
132            resized_width,
133            self.input_height as usize,
134        )
135    }
136
137    fn run_group(
138        &self,
139        image: &NdTensorView<f32, 3>,
140        lines: Vec<TextRecLine>,
141        page_rect: Rect,
142        compile_width: usize,
143        rec_img_height: u32,
144        decode_method: DecodeMethod,
145        alphabet_len: usize,
146        excluded_char_labels: Option<&[usize]>,
147    ) -> Result<Vec<LineRecResult>> {
148        self.ensure_compiled(compile_width, lines.len().max(1))?;
149        let rec_input = prepare_text_line_batch(
150            image,
151            &lines,
152            page_rect,
153            rec_img_height as usize,
154            compile_width,
155        );
156        let input: Vec<f32> = rec_input.iter().copied().collect();
157        let mut cache = self.cache.lock().map_err(|_| anyhow!("lock poisoned"))?;
158        let compiled = cache.get_mut(&compile_width).unwrap();
159        let outputs = compiled.run(&[("image", input.as_slice())]);
160        let flat_out = outputs
161            .into_iter()
162            .next()
163            .ok_or_else(|| anyhow!("recognition returned no output"))?;
164
165        let n_classes = alphabet_len + 1;
166        let seq_len = flat_out.len() / (lines.len() * n_classes);
167        let n = lines.len() * seq_len * n_classes;
168        let mut logits = flat_out[..n].to_vec();
169        log_softmax_last_axis(&mut logits, n_classes);
170        let mut rec_output = NdTensor::from_data([lines.len(), seq_len, n_classes], logits);
171
172        if n_classes != rec_output.size(2) {
173            return Err(anyhow!(
174                "recognition classes ({}) != alphabet+blank ({n_classes})",
175                rec_output.size(2)
176            ));
177        }
178        let ctc_input_len = rec_output.size(1);
179
180        Ok(lines
181            .into_iter()
182            .enumerate()
183            .map(|(group_line_index, line)| {
184                let mut input_seq_slice = rec_output.slice_mut([group_line_index]);
185                let input_seq =
186                    filter_excluded_char_labels(excluded_char_labels, &mut input_seq_slice);
187                let ctc_output = crate::ctc::decode(input_seq, decode_method);
188                LineRecResult {
189                    line,
190                    rec_input_len: compile_width,
191                    ctc_input_len,
192                    ctc_output,
193                }
194            })
195            .collect())
196    }
197
198    pub fn recognize_text_lines(
199        &self,
200        image: NdTensorView<f32, 3>,
201        lines: &[Vec<RotatedRect>],
202        decode_method: DecodeMethod,
203        alphabet: &str,
204        excluded_char_labels: Option<&[usize]>,
205    ) -> Result<Vec<Option<TextLine>>> {
206        let [_, img_height, img_width] = image.shape();
207        let page_rect = Rect::from_hw(img_height as i32, img_width as i32);
208        let rec_img_height = self.input_height;
209
210        let mut line_groups: HashMap<i32, Vec<TextRecLine>> = HashMap::new();
211        for (line_index, word_rects) in lines.iter().enumerate() {
212            let line_rect = bounding_rect(word_rects.iter())
213                .expect("line has no words")
214                .integral_bounding_rect();
215            let resized_width =
216                resized_line_width(line_rect.width(), line_rect.height(), rec_img_height as i32);
217            let group_width = resized_width.next_multiple_of(50);
218            line_groups
219                .entry(group_width as i32)
220                .or_default()
221                .push(TextRecLine {
222                    index: line_index,
223                    region: Polygon::new(line_polygon(word_rects)),
224                    resized_width,
225                });
226        }
227
228        let max_lines_per_group = 20;
229        let line_groups: Vec<(usize, Vec<TextRecLine>)> = line_groups
230            .into_iter()
231            .flat_map(|(group_width, lines)| {
232                lines
233                    .chunks(max_lines_per_group)
234                    .map(|chunk| (group_width as usize, chunk.to_vec()))
235                    .collect::<Vec<_>>()
236            })
237            .collect();
238
239        let alphabet_len = alphabet.chars().count();
240        for &(compile_width, ref chunk) in &line_groups {
241            self.ensure_compiled(compile_width, chunk.len().max(1))?;
242        }
243
244        let group_results: Result<Vec<Vec<LineRecResult>>> = line_groups
245            .into_iter()
246            .map(|(compile_width, lines)| {
247                self.run_group(
248                    &image,
249                    lines,
250                    page_rect,
251                    compile_width,
252                    rec_img_height,
253                    decode_method,
254                    alphabet_len,
255                    excluded_char_labels,
256                )
257            })
258            .collect();
259
260        let mut line_rec_results: Vec<LineRecResult> =
261            group_results?.into_iter().flatten().collect();
262        line_rec_results.sort_by_key(|r| r.line.index);
263        Ok(text_lines_from_recognition_results(
264            &line_rec_results,
265            alphabet,
266        ))
267    }
268}