Skip to main content

rlx_ocr/recognition/
line_batch.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shared line cropping, batching, and CTC post-processing for recognition backends.
17
18use crate::ctc::CtcHypothesis;
19use crate::geom::{downwards_line, leftmost_edge, rightmost_edge};
20use crate::preprocess::BLACK_VALUE;
21use crate::text::{TextChar, TextLine};
22#[cfg(feature = "tensor-ops")]
23use rten::FloatOperators;
24use rten_imageproc::{BoundingRect, Line, Point, Polygon, Rect, RotatedRect};
25use rten_tensor::prelude::*;
26use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut};
27
28#[derive(Clone)]
29pub struct TextRecLine {
30    pub index: usize,
31    pub region: Polygon,
32    pub resized_width: u32,
33}
34
35pub struct LineRecResult {
36    pub line: TextRecLine,
37    pub rec_input_len: usize,
38    pub ctc_input_len: usize,
39    pub ctc_output: CtcHypothesis,
40}
41
/// Return the width to resize a text line image to, preserving its aspect
/// ratio at the target `height`.
///
/// The result is clamped to `[10, 2400]` pixels. When the aspect ratio is
/// undefined (e.g. a 0x0 line, which yields NaN) the minimum width is
/// returned rather than 0.
pub fn resized_line_width(orig_width: i32, orig_height: i32, height: i32) -> u32 {
    const MIN_WIDTH: f32 = 10.;
    const MAX_WIDTH: f32 = 2400.;

    let aspect_ratio = orig_width as f32 / orig_height as f32;
    let width = height as f32 * aspect_ratio;

    // `f32::clamp` returns NaN for a NaN input, and `NaN as u32` saturates to
    // 0, which would silently bypass the minimum width.
    if width.is_nan() {
        return MIN_WIDTH as u32;
    }
    width.clamp(MIN_WIDTH, MAX_WIDTH) as u32
}
48
49pub fn line_polygon(words: &[RotatedRect]) -> Vec<Point> {
50    let mut polygon = Vec::new();
51    let floor_point = |p: rten_imageproc::PointF| Point::from_yx(p.y as i32, p.x as i32);
52
53    for word_rect in words.iter() {
54        let (left, right) = (
55            downwards_line(leftmost_edge(word_rect)),
56            downwards_line(rightmost_edge(word_rect)),
57        );
58        polygon.push(floor_point(left.start));
59        polygon.push(floor_point(right.start));
60    }
61    for word_rect in words.iter().rev() {
62        let (left, right) = (
63            downwards_line(leftmost_edge(word_rect)),
64            downwards_line(rightmost_edge(word_rect)),
65        );
66        polygon.push(floor_point(right.end));
67        polygon.push(floor_point(left.end));
68    }
69    polygon
70}
71
72pub fn prepare_text_line(
73    image: NdTensorView<f32, 3>,
74    page_rect: Rect,
75    line_region: &Polygon,
76    resized_width: u32,
77    output_height: usize,
78) -> NdTensor<f32, 2> {
79    let page_index_rect = page_rect.adjust_tlbr(0, 0, -1, -1);
80    let grey_chan = image.slice([0]);
81    let line_rect = line_region.bounding_rect();
82    let mut line_img = NdTensor::full(
83        [line_rect.height() as usize, line_rect.width() as usize],
84        BLACK_VALUE,
85    );
86
87    for in_p in line_region.fill_iter() {
88        let out_p = Point::from_yx(in_p.y - line_rect.top(), in_p.x - line_rect.left());
89        if !page_index_rect.contains_point(in_p) || !page_index_rect.contains_point(out_p) {
90            continue;
91        }
92        line_img[[out_p.y as usize, out_p.x as usize]] =
93            grey_chan[[in_p.y as usize, in_p.x as usize]];
94    }
95
96    let resized_line_img = line_img
97        .reshaped([1, 1, line_img.size(0), line_img.size(1)])
98        .resize_image([output_height, resized_width as usize])
99        .unwrap();
100    let out_shape = [resized_line_img.size(2), resized_line_img.size(3)];
101    resized_line_img.into_shape(out_shape)
102}
103
104pub fn prepare_text_line_batch(
105    image: &NdTensorView<f32, 3>,
106    lines: &[TextRecLine],
107    page_rect: Rect,
108    output_height: usize,
109    output_width: usize,
110) -> NdTensor<f32, 4> {
111    let mut output = NdTensor::full([lines.len(), 1, output_height, output_width], BLACK_VALUE);
112    for (group_line_index, line) in lines.iter().enumerate() {
113        let resized_line_img = prepare_text_line(
114            image.view(),
115            page_rect,
116            &line.region,
117            line.resized_width,
118            output_height,
119        );
120        output
121            .slice_mut((group_line_index, 0, .., ..(line.resized_width as usize)))
122            .copy_from(&resized_line_img);
123    }
124    output
125}
126
127fn polygon_slice_bounding_rect(poly: Polygon, min_x: i32, max_x: i32) -> Option<Rect> {
128    poly.edges()
129        .filter_map(|e| {
130            let e = e.rightwards();
131            if (e.start.x < min_x && e.end.x < min_x) || (e.start.x > max_x && e.end.x > max_x) {
132                return None;
133            }
134            let trunc_edge_start = e
135                .to_f32()
136                .y_for_x(min_x as f32)
137                .map_or(e.start, |y| Point::from_yx(y.round() as i32, min_x));
138            let trunc_edge_end = e
139                .to_f32()
140                .y_for_x(max_x as f32)
141                .map_or(e.end, |y| Point::from_yx(y.round() as i32, max_x));
142            Some(Line::from_endpoints(trunc_edge_start, trunc_edge_end))
143        })
144        .fold(None, |bounding_rect, e| {
145            let edge_br = e.bounding_rect();
146            bounding_rect.map(|br| br.union(edge_br)).or(Some(edge_br))
147        })
148}
149
150pub fn text_lines_from_recognition_results(
151    results: &[LineRecResult],
152    alphabet: &str,
153) -> Vec<Option<TextLine>> {
154    results
155        .iter()
156        .map(|result| {
157            let line_rect = result.line.region.bounding_rect();
158            let x_scale_factor = (line_rect.width() as f32) / (result.line.resized_width as f32);
159            let downsample_factor =
160                (result.rec_input_len as f32 / result.ctc_input_len as f32).round() as u32;
161
162            let steps = result.ctc_output.steps();
163            let text_line: Vec<TextChar> = steps
164                .iter()
165                .enumerate()
166                .filter_map(|(i, step)| {
167                    let start_x = step.pos * downsample_factor;
168                    let end_x = if let Some(next_step) = steps.get(i + 1) {
169                        next_step.pos * downsample_factor
170                    } else {
171                        result.line.resized_width
172                    };
173                    let [start_x, end_x] = [start_x, end_x]
174                        .map(|x| line_rect.left() + (x as f32 * x_scale_factor) as i32);
175                    if start_x >= line_rect.right() {
176                        return None;
177                    }
178                    let ch = alphabet
179                        .chars()
180                        .nth((step.label.saturating_sub(1)) as usize)
181                        .unwrap_or('?');
182                    Some(TextChar {
183                        char: ch,
184                        rect: polygon_slice_bounding_rect(
185                            result.line.region.clone(),
186                            start_x,
187                            end_x,
188                        )
189                        .expect("invalid X coords"),
190                    })
191                })
192                .collect();
193            if text_line.is_empty() {
194                None
195            } else {
196                Some(TextLine::new(text_line))
197            }
198        })
199        .collect()
200}
201
202pub fn filter_excluded_char_labels<'a>(
203    excluded_char_labels: Option<&[usize]>,
204    input_seq_slice: &'a mut NdTensorViewMut<f32, 2>,
205) -> NdTensorView<'a, f32, 2> {
206    if let Some(excluded_char_labels) = excluded_char_labels {
207        for row in 0..input_seq_slice.size(0) {
208            for &excluded_char_label in excluded_char_labels.iter() {
209                input_seq_slice[[row, excluded_char_label]] = f32::NEG_INFINITY;
210            }
211        }
212    }
213    input_seq_slice.view()
214}
215
216pub fn bounding_rect<'a, I: Iterator<Item = &'a RotatedRect>>(
217    rects: I,
218) -> Option<rten_imageproc::RectF> {
219    rects.fold(None, |br, r| match br {
220        Some(br) => Some(br.union(r.bounding_rect())),
221        None => Some(r.bounding_rect()),
222    })
223}