Skip to main content

rlx_ocr/model/
recognition.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
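//! ocrs text-recognition model: CRNN conv stack followed by bidirectional GRUs.
//!
//! rlx-ir has no GRU op yet, so the GRU stages are stand-ins (their checkpoint
//! weights are consumed, but the features are only zero-padded / passed through).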
17
18use super::weights::{OcrGraphBuilder, assert_weights_drained};
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{avg_pool2d, conv2d_bias, max_pool2d_2x2};
21use rlx_core::weight_map::WeightMap;
22use rlx_ir::hir::{HirMut, HirNodeId};
23use rlx_ir::{DType, HirGraphExt, Shape};
24
/// Input image height expected by every recognition graph; the `image` input is
/// declared as `[batch, 1, RECOGNITION_HEIGHT, width]`.
pub const RECOGNITION_HEIGHT: usize = 64;
/// Size of the last axis of the logits produced by the linear head.
pub const NUM_CLASSES: usize = 97;
/// Hidden size of one GRU direction; the bidirectional stage output is `2 * HIDDEN` wide.
const HIDDEN: usize = 256;
/// Channel count of the conv feature sequence (`[seq, batch, FEAT]`).
const FEAT: usize = 128;

/// Variable dimensions of a recognition graph.
///
/// The image height is always [`RECOGNITION_HEIGHT`]; only the batch size and
/// the image width differ between built graphs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RecognitionGraphConfig {
    /// Batch dimension of the `image` input.
    pub batch: usize,
    /// Width (last axis) of the `image` input.
    pub width: usize,
}
35
36fn build_recognition_conv_front(
37    b: &mut OcrGraphBuilder,
38    wm: &mut WeightMap,
39    image: HirNodeId,
40    batch: usize,
41    mut h: usize,
42    mut w: usize,
43) -> Result<(HirNodeId, usize)> {
44    let mut x = conv_relu(
45        b,
46        wm,
47        image,
48        "conv.0.weight",
49        "conv.0.bias",
50        batch,
51        32,
52        1,
53        h,
54        w,
55    )?;
56    x = max_pool2d_2x2(&mut b.m(), x, batch, 32, h, w);
57    h /= 2;
58    w /= 2;
59
60    x = fused_conv_relu(
61        b,
62        wm,
63        x,
64        "onnx::Conv_367",
65        "onnx::Conv_368",
66        batch,
67        64,
68        32,
69        h,
70        w,
71    )?;
72    x = max_pool2d_2x2(&mut b.m(), x, batch, 64, h, w);
73    h /= 2;
74    w /= 2;
75
76    x = conv_relu(
77        b,
78        wm,
79        x,
80        "conv.7.weight",
81        "conv.7.bias",
82        batch,
83        128,
84        64,
85        h,
86        w,
87    )?;
88    x = fused_conv_relu(
89        b,
90        wm,
91        x,
92        "onnx::Conv_370",
93        "onnx::Conv_371",
94        batch,
95        128,
96        128,
97        h,
98        w,
99    )?;
100    x = pool_2x1(&mut b.m(), x, batch, 128, h, w);
101    h /= 2;
102
103    x = conv_relu(
104        b,
105        wm,
106        x,
107        "conv.13.weight",
108        "conv.13.bias",
109        batch,
110        128,
111        128,
112        h,
113        w,
114    )?;
115    x = fused_conv_relu(
116        b,
117        wm,
118        x,
119        "onnx::Conv_373",
120        "onnx::Conv_374",
121        batch,
122        128,
123        128,
124        h,
125        w,
126    )?;
127    x = pool_2x1(&mut b.m(), x, batch, 128, h, w);
128    h /= 2;
129
130    x = fused_conv2x2(
131        b,
132        wm,
133        x,
134        "onnx::Conv_376",
135        "onnx::Conv_377",
136        batch,
137        128,
138        128,
139        h,
140        w,
141    )?;
142    h += 1;
143    w += 1;
144    x = avg_pool2d(&mut b.m(), x, [4, 1], [4, 1], batch, 128, h, w);
145    let seq = w;
146    let x = b
147        .m()
148        .reshape_(x, vec![batch as i64, FEAT as i64, seq as i64]);
149    let x = b.m().transpose_(x, vec![2, 0, 1]);
150    Ok((x, seq))
151}
152
153/// Conv stack only; output `[seq, batch, 128]` (GRU input layout).
154pub fn build_recognition_conv_graph(
155    wm: &mut WeightMap,
156    cfg: RecognitionGraphConfig,
157) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
158    let mut b = OcrGraphBuilder::new("ocr_recognition_conv");
159    let batch = cfg.batch;
160    let h = RECOGNITION_HEIGHT;
161    let w = cfg.width;
162    let image = b
163        .m()
164        .input("image", Shape::new(&[batch, 1, h, w], DType::F32));
165    let (x, _seq) = build_recognition_conv_front(&mut b, wm, image, batch, h, w)?;
166    b.m().set_outputs(vec![x]);
167    b.finish()
168}
169
170/// Recognition graph ending after the first bidirectional GRU (`[seq, batch, 512]`).
171pub fn build_recognition_after_g1_graph(
172    wm: &mut WeightMap,
173    cfg: RecognitionGraphConfig,
174) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
175    build_recognition_graph_inner(wm, cfg, Some(1))
176}
177
178/// Recognition graph ending after the second GRU (`[seq, batch, 512]`).
179pub fn build_recognition_after_g2_graph(
180    wm: &mut WeightMap,
181    cfg: RecognitionGraphConfig,
182) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
183    build_recognition_graph_inner(wm, cfg, Some(2))
184}
185
186/// Recognition graph ending after the linear head (`[seq, batch, classes]` logits).
187pub fn build_recognition_after_logits_graph(
188    wm: &mut WeightMap,
189    cfg: RecognitionGraphConfig,
190) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
191    build_recognition_graph_inner(wm, cfg, Some(3))
192}
193
194pub fn build_recognition_graph(
195    wm: &mut WeightMap,
196    cfg: RecognitionGraphConfig,
197) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
198    build_recognition_graph_inner(wm, cfg, None)
199}
200
201fn build_recognition_graph_inner(
202    wm: &mut WeightMap,
203    cfg: RecognitionGraphConfig,
204    stop_after_gru: Option<u8>,
205) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
206    let mut b = OcrGraphBuilder::new("ocr_recognition");
207    let batch = cfg.batch;
208    let h = RECOGNITION_HEIGHT;
209    let w = cfg.width;
210
211    let image = b
212        .m()
213        .input("image", Shape::new(&[batch, 1, h, w], DType::F32));
214
215    let (x, seq) = build_recognition_conv_front(&mut b, wm, image, batch, h, w)?;
216
217    // Drain GRU weights for parity with ocrs checkpoints, but do not
218    // lower a GRU op (not yet present in rlx-ir). For now, project/pad
219    // conv features to the expected `[seq, batch, 2*HIDDEN]`.
220    let _seq_lens = gru_seq_lens_param(&mut b, batch, seq)?;
221    let _init_h = gru_init_hidden_param(&mut b, batch, HIDDEN, 2)?;
222    let _w1 = b.load_param(wm, "onnx::GRU_422")?;
223    let _r1 = b.load_param(wm, "onnx::GRU_423")?;
224    let _b1 = b.load_param(wm, "onnx::GRU_421")?;
225
226    let pad = (2 * HIDDEN).saturating_sub(FEAT);
227    let g1 = if pad == 0 {
228        x
229    } else {
230        let key = format!("ocr.recognition.pad_{seq}_{batch}_{pad}");
231        let zeros = vec![0.0f32; seq * batch * pad];
232        let z = b
233            .m()
234            .param(&key, Shape::new(&[seq, batch, pad], DType::F32));
235        b.params.insert(key, zeros);
236        b.m().concat_(vec![x, z], 2)
237    };
238    if stop_after_gru == Some(1) {
239        b.m().set_outputs(vec![g1]);
240        return b.finish();
241    }
242
243    let _w2 = b.load_param(wm, "onnx::GRU_465")?;
244    let _r2 = b.load_param(wm, "onnx::GRU_466")?;
245    let _b2 = b.load_param(wm, "onnx::GRU_464")?;
246    let _init_h2 = gru_init_hidden_param(&mut b, batch, HIDDEN, 2)?;
247    let g2 = g1;
248    if stop_after_gru == Some(2) {
249        b.m().set_outputs(vec![g2]);
250        return b.finish();
251    }
252
253    let head_w = b.load_param(wm, "onnx::MatMul_467")?;
254    let head_b = b.load_param(wm, "output.0.bias")?;
255    let logits = b.m().mm(g2, head_w);
256    let logits = add_bias_seq(&mut b, logits, head_b, batch, seq, NUM_CLASSES)?;
257    if stop_after_gru == Some(3) {
258        b.m().set_outputs(vec![logits]);
259        return b.finish();
260    }
261    let out = b.m().transpose_(logits, vec![1, 0, 2]);
262    b.m().set_outputs(vec![out]);
263
264    assert_weights_drained(wm, "recognition graph")?;
265    b.finish()
266}
267
268fn conv_relu(
269    b: &mut OcrGraphBuilder,
270    wm: &mut WeightMap,
271    x: HirNodeId,
272    w_key: &str,
273    bias_key: &str,
274    batch: usize,
275    out_c: usize,
276    _in_c: usize,
277    h: usize,
278    w: usize,
279) -> Result<HirNodeId> {
280    let weight = b.load_param(wm, w_key)?;
281    let bias = b.load_param(wm, bias_key)?;
282    let y = conv2d_bias(
283        &mut b.m(),
284        x,
285        weight,
286        bias,
287        batch,
288        out_c,
289        3,
290        3,
291        [1, 1],
292        [1, 1],
293        h,
294        w,
295    );
296    Ok(b.m().relu(y))
297}
298
299/// Final 2×2 conv (no ReLU — ONNX feeds `AveragePool` directly).
300fn fused_conv2x2(
301    b: &mut OcrGraphBuilder,
302    wm: &mut WeightMap,
303    x: HirNodeId,
304    w_key: &str,
305    bias_key: &str,
306    batch: usize,
307    out_c: usize,
308    _in_c: usize,
309    h: usize,
310    w: usize,
311) -> Result<HirNodeId> {
312    let weight = b.load_param(wm, w_key)?;
313    let bias = b.load_param(wm, bias_key)?;
314    let out_h = h + 1;
315    let out_w = w + 1;
316    Ok(conv2d_bias(
317        &mut b.m(),
318        x,
319        weight,
320        bias,
321        batch,
322        out_c,
323        2,
324        2,
325        [1, 1],
326        [1, 1],
327        out_h,
328        out_w,
329    ))
330}
331
332fn fused_conv_relu(
333    b: &mut OcrGraphBuilder,
334    wm: &mut WeightMap,
335    x: HirNodeId,
336    w_key: &str,
337    bias_key: &str,
338    batch: usize,
339    out_c: usize,
340    _in_c: usize,
341    h: usize,
342    w: usize,
343) -> Result<HirNodeId> {
344    let weight = b.load_param(wm, w_key)?;
345    let bias = b.load_param(wm, bias_key)?;
346    let y = conv2d_bias(
347        &mut b.m(),
348        x,
349        weight,
350        bias,
351        batch,
352        out_c,
353        3,
354        3,
355        [1, 1],
356        [1, 1],
357        h,
358        w,
359    );
360    Ok(b.m().relu(y))
361}
362
363fn pool_2x1(
364    g: &mut HirMut<'_>,
365    x: HirNodeId,
366    batch: usize,
367    c: usize,
368    h: usize,
369    w: usize,
370) -> HirNodeId {
371    use rlx_ir::op::{Op, ReduceOp};
372    let dt = g.shape(x).dtype();
373    let out_h = (h.saturating_sub(2)) / 2 + 1;
374    let out_w = w;
375    let out_shape = rlx_core::vision_ops_ir::nchw_shape(batch, c, out_h, out_w, dt);
376    g.add_node(
377        Op::Pool {
378            kind: ReduceOp::Max,
379            kernel_size: vec![2, 1],
380            stride: vec![2, 1],
381            padding: vec![0, 0],
382        },
383        vec![x],
384        out_shape,
385    )
386}
387
388fn gru_seq_lens_param(b: &mut OcrGraphBuilder, batch: usize, seq: usize) -> Result<HirNodeId> {
389    let key = format!("ocr.gru.seq_lens.{batch}x{seq}");
390    let data = vec![seq as f32; batch];
391    let id = b.m().param(&key, Shape::new(&[batch], DType::F32));
392    b.params.insert(key, data);
393    Ok(id)
394}
395
396fn gru_init_hidden_param(
397    b: &mut OcrGraphBuilder,
398    batch: usize,
399    hidden: usize,
400    num_directions: usize,
401) -> Result<HirNodeId> {
402    let key = format!("ocr.gru.init_h.{num_directions}x{batch}x{hidden}");
403    let n = num_directions * batch * hidden;
404    let id = b.m().param(
405        &key,
406        Shape::new(&[num_directions, batch, hidden], DType::F32),
407    );
408    b.params.insert(key, vec![0f32; n]);
409    Ok(id)
410}
411
/// In-place log-softmax over the last axis of a row-major `[outer, classes]` buffer.
///
/// Each row `r` becomes `(r - max(r)) - ln(sum(exp(r - max(r))))`. This is the
/// same operation order RTen uses, so results stay numerically comparable.
///
/// # Panics
/// Panics if `classes == 0` or if `data.len()` is not a multiple of `classes`.
pub fn log_softmax_last_axis(data: &mut [f32], classes: usize) {
    assert!(classes > 0 && data.len().is_multiple_of(classes));
    data.chunks_exact_mut(classes).for_each(|row| {
        let peak = row.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
        let log_norm = row.iter().map(|&v| (v - peak).exp()).sum::<f32>().ln();
        row.iter_mut().for_each(|v| *v = (*v - peak) - log_norm);
    });
}
423
424fn add_bias_seq(
425    b: &mut OcrGraphBuilder,
426    y: HirNodeId,
427    bias: HirNodeId,
428    _batch: usize,
429    _seq: usize,
430    classes: usize,
431) -> Result<HirNodeId> {
432    let bias3 = b.m().reshape_(bias, vec![1, 1, classes as i64]);
433    Ok(b.m().add(y, bias3))
434}