Skip to main content

rlx_ocr/model/
detection.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! ocrs text-detection U-Net ([`ocrs_models::DetectionModel`](https://github.com/robertknight/ocrs-models)).
17
18use super::weights::{
19    DET_DW_KEYS, DET_ONNX_PW, OcrGraphBuilder, assert_weights_drained, detection_input_hw,
20};
21use anyhow::Result;
22use rlx_core::vision_ops_ir::{
23    conv_transpose2d_k3s2_bias_trim, conv2d_bias, conv2d_bias_groups, max_pool2d_2x2, sigmoid_nchw,
24};
25use rlx_core::weight_map::WeightMap;
26use rlx_ir::hir::HirNodeId;
27use rlx_ir::{DType, HirGraphExt, Shape};
28
/// Reference input size for the HF detection checkpoint, stored as a `(usize, usize)` pair.
///
/// The `_HW` suffix suggests the order is `(height, width)`. The size actually used when
/// compiling the graph can be overridden via `OCR_DETECTION_HW`.
// NOTE(review): the lint allowance implies nothing in this crate reads the constant directly;
// confirm it is still wanted as a public reference value.
#[allow(dead_code)]
pub const DEFAULT_DETECTION_INPUT_HW: (usize, usize) = (800, 600);
32
/// Static shape of the detection graph's `image` input.
///
/// The graph is built for exactly one `[batch, 1, height, width]` input, so these values are
/// fixed at graph-construction time.
#[derive(Debug, Clone, Copy)]
pub struct DetectionGraphConfig {
    /// Number of images per invocation (leading dimension of the input).
    pub batch: usize,
    /// Input height in pixels.
    pub height: usize,
    /// Input width in pixels.
    pub width: usize,
}
39
40impl Default for DetectionGraphConfig {
41    fn default() -> Self {
42        let (height, width) = detection_input_hw();
43        Self {
44            batch: 1,
45            height,
46            width,
47        }
48    }
49}
50
/// Channel widths at each U-Net level (matches `ocrs_models`).
///
/// Entry `0` is the channel count produced by the input convolution; entry `i + 1` is the
/// channel count produced by encoder level `i`. The decoder walks the same table in reverse,
/// so the table length minus one is the number of down/up levels.
const DEPTH_SCALE: [usize; 7] = [8, 16, 32, 32, 64, 128, 256];
53
54pub fn build_detection_graph(
55    wm: &mut WeightMap,
56    cfg: DetectionGraphConfig,
57) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
58    let mut b = OcrGraphBuilder::new("ocr_detection");
59    let f = DType::F32;
60    let batch = cfg.batch;
61    let mut h = cfg.height;
62    let mut w = cfg.width;
63
64    let image = b.m().input("image", Shape::new(&[batch, 1, h, w], f));
65
66    let mut block = 0usize;
67    let mut x = double_conv(
68        &mut b,
69        wm,
70        image,
71        &mut block,
72        1,
73        DEPTH_SCALE[0],
74        batch,
75        h,
76        w,
77    )?;
78    let in_conv_skip = (x, h, w);
79
80    let mut x_down: Vec<(HirNodeId, usize, usize)> = Vec::new();
81    for level in 0..DEPTH_SCALE.len() - 1 {
82        let in_c = DEPTH_SCALE[level];
83        let out_c = DEPTH_SCALE[level + 1];
84        x = double_conv(&mut b, wm, x, &mut block, in_c, out_c, batch, h, w)?;
85        x = max_pool2d_2x2(&mut b.m(), x, batch, out_c, h, w);
86        h /= 2;
87        w /= 2;
88        x_down.push((x, h, w));
89    }
90
91    let mut x_up = x;
92    let mut up_h = h;
93    let mut up_w = w;
94    for up_idx in (0..DEPTH_SCALE.len() - 1).rev() {
95        let out_c = DEPTH_SCALE[up_idx];
96        let cross_c = DEPTH_SCALE[up_idx];
97        let (skip, skip_h, skip_w) = if up_idx == 0 {
98            (in_conv_skip.0, in_conv_skip.1, in_conv_skip.2)
99        } else {
100            let (skip_node, sh, sw) = x_down[up_idx - 1];
101            (skip_node, sh, sw)
102        };
103
104        let up_w_key = format!("up.{up_idx}.up.weight");
105        let up_b_key = format!("up.{up_idx}.up.bias");
106        let up_weight = b.load_param(wm, &up_w_key)?;
107        let up_bias = b.load_param(wm, &up_b_key)?;
108        let upscaled = conv_transpose2d_k3s2_bias_trim(
109            &mut b.m(),
110            x_up,
111            up_weight,
112            up_bias,
113            batch,
114            out_c,
115            up_h,
116            up_w,
117            skip_h,
118            skip_w,
119        );
120        up_h = skip_h;
121        up_w = skip_w;
122
123        let cat = b.m().concat_(vec![upscaled, skip], 1);
124        x_up = double_conv(
125            &mut b,
126            wm,
127            cat,
128            &mut block,
129            out_c + cross_c,
130            out_c,
131            batch,
132            up_h,
133            up_w,
134        )?;
135    }
136
137    let out_w = b.load_param(wm, "out_conv.0.weight")?;
138    let out_b = b.load_param(wm, "out_conv.0.bias")?;
139    let logits = conv2d_bias(
140        &mut b.m(),
141        x_up,
142        out_w,
143        out_b,
144        batch,
145        1,
146        1,
147        1,
148        [1, 1],
149        [0, 0],
150        up_h,
151        up_w,
152    );
153    let mask = sigmoid_nchw(&mut b.m(), logits);
154    b.m().set_outputs(vec![mask]);
155
156    assert_weights_drained(wm, "detection graph")?;
157    b.finish()
158}
159
160fn double_conv(
161    b: &mut OcrGraphBuilder,
162    wm: &mut WeightMap,
163    mut x: HirNodeId,
164    block: &mut usize,
165    in_c: usize,
166    out_c: usize,
167    batch: usize,
168    h: usize,
169    w: usize,
170) -> Result<HirNodeId> {
171    let (pw0_w, pw0_b) = DET_ONNX_PW[*block];
172    // `DepthwiseConv`: dw 3×3 → pw 1×1 (+ fused BN in onnx keys) → ReLU (once).
173    x = depthwise_conv(b, wm, x, DET_DW_KEYS[*block], in_c, batch, h, w)?;
174    x = pointwise_relu(b, wm, x, pw0_w, pw0_b, in_c, out_c, batch, h, w)?;
175    *block += 1;
176    let (pw1_w, pw1_b) = DET_ONNX_PW[*block];
177    x = depthwise_conv(b, wm, x, DET_DW_KEYS[*block], out_c, batch, h, w)?;
178    x = pointwise_relu(b, wm, x, pw1_w, pw1_b, out_c, out_c, batch, h, w)?;
179    *block += 1;
180    Ok(x)
181}
182
183fn depthwise_conv(
184    b: &mut OcrGraphBuilder,
185    wm: &mut WeightMap,
186    x: HirNodeId,
187    dw_key: &str,
188    channels: usize,
189    batch: usize,
190    h: usize,
191    w: usize,
192) -> Result<HirNodeId> {
193    let weight = b.load_param(wm, dw_key)?;
194    let bias = b.zero_bias(channels)?;
195    Ok(conv2d_bias_groups(
196        &mut b.m(),
197        x,
198        weight,
199        bias,
200        batch,
201        channels,
202        3,
203        3,
204        [1, 1],
205        [1, 1],
206        channels,
207        h,
208        w,
209    ))
210}
211
212fn pointwise_relu(
213    b: &mut OcrGraphBuilder,
214    wm: &mut WeightMap,
215    x: HirNodeId,
216    w_key: &str,
217    b_key: &str,
218    _in_c: usize,
219    out_c: usize,
220    batch: usize,
221    h: usize,
222    w: usize,
223) -> Result<HirNodeId> {
224    let weight = b.load_param(wm, w_key)?;
225    let bias = b.load_param(wm, b_key)?;
226    let y = conv2d_bias(
227        &mut b.m(),
228        x,
229        weight,
230        bias,
231        batch,
232        out_c,
233        1,
234        1,
235        [1, 1],
236        [0, 0],
237        h,
238        w,
239    );
240    Ok(b.m().relu(y))
241}