1use super::weights::{
19 DET_DW_KEYS, DET_ONNX_PW, OcrGraphBuilder, assert_weights_drained, detection_input_hw,
20};
21use anyhow::Result;
22use rlx_core::vision_ops_ir::{
23 conv_transpose2d_k3s2_bias_trim, conv2d_bias, conv2d_bias_groups, max_pool2d_2x2, sigmoid_nchw,
24};
25use rlx_core::weight_map::WeightMap;
26use rlx_ir::hir::HirNodeId;
27use rlx_ir::{DType, HirGraphExt, Shape};
28
/// Fallback input size for the detection graph.
///
/// NOTE(review): the `_HW` suffix suggests `(height, width)` ordering, i.e.
/// 800 rows by 600 columns — confirm against the call sites.
// Nothing in this part of the file reads the constant, which is presumably
// why the `dead_code` lint is silenced — TODO confirm it is still needed.
#[allow(dead_code)]
pub const DEFAULT_DETECTION_INPUT_HW: (usize, usize) = (800, 600);
32
/// Static dimensions of the `image` input that the detection graph is built for.
///
/// The graph input is declared with shape `[batch, 1, height, width]`.
#[derive(Clone, Copy, Debug)]
pub struct DetectionGraphConfig {
    /// Leading (batch) dimension of the input tensor.
    pub batch: usize,
    /// Input height (third dimension of the input shape).
    pub height: usize,
    /// Input width (fourth dimension of the input shape).
    pub width: usize,
}
39
40impl Default for DetectionGraphConfig {
41 fn default() -> Self {
42 let (height, width) = detection_input_hw();
43 Self {
44 batch: 1,
45 height,
46 width,
47 }
48 }
49}
50
/// Channel counts used by `build_detection_graph`.
///
/// Entry 0 is the output width of the initial double convolution; each
/// adjacent pair `(DEPTH_SCALE[i], DEPTH_SCALE[i + 1])` gives the
/// input/output channels of one downsampling level, and entry `i` is reused
/// as the output width of the matching upsampling level.
const DEPTH_SCALE: [usize; 7] = [8, 16, 32, 32, 64, 128, 256];
53
54pub fn build_detection_graph(
55 wm: &mut WeightMap,
56 cfg: DetectionGraphConfig,
57) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
58 let mut b = OcrGraphBuilder::new("ocr_detection");
59 let f = DType::F32;
60 let batch = cfg.batch;
61 let mut h = cfg.height;
62 let mut w = cfg.width;
63
64 let image = b.m().input("image", Shape::new(&[batch, 1, h, w], f));
65
66 let mut block = 0usize;
67 let mut x = double_conv(
68 &mut b,
69 wm,
70 image,
71 &mut block,
72 1,
73 DEPTH_SCALE[0],
74 batch,
75 h,
76 w,
77 )?;
78 let in_conv_skip = (x, h, w);
79
80 let mut x_down: Vec<(HirNodeId, usize, usize)> = Vec::new();
81 for level in 0..DEPTH_SCALE.len() - 1 {
82 let in_c = DEPTH_SCALE[level];
83 let out_c = DEPTH_SCALE[level + 1];
84 x = double_conv(&mut b, wm, x, &mut block, in_c, out_c, batch, h, w)?;
85 x = max_pool2d_2x2(&mut b.m(), x, batch, out_c, h, w);
86 h /= 2;
87 w /= 2;
88 x_down.push((x, h, w));
89 }
90
91 let mut x_up = x;
92 let mut up_h = h;
93 let mut up_w = w;
94 for up_idx in (0..DEPTH_SCALE.len() - 1).rev() {
95 let out_c = DEPTH_SCALE[up_idx];
96 let cross_c = DEPTH_SCALE[up_idx];
97 let (skip, skip_h, skip_w) = if up_idx == 0 {
98 (in_conv_skip.0, in_conv_skip.1, in_conv_skip.2)
99 } else {
100 let (skip_node, sh, sw) = x_down[up_idx - 1];
101 (skip_node, sh, sw)
102 };
103
104 let up_w_key = format!("up.{up_idx}.up.weight");
105 let up_b_key = format!("up.{up_idx}.up.bias");
106 let up_weight = b.load_param(wm, &up_w_key)?;
107 let up_bias = b.load_param(wm, &up_b_key)?;
108 let upscaled = conv_transpose2d_k3s2_bias_trim(
109 &mut b.m(),
110 x_up,
111 up_weight,
112 up_bias,
113 batch,
114 out_c,
115 up_h,
116 up_w,
117 skip_h,
118 skip_w,
119 );
120 up_h = skip_h;
121 up_w = skip_w;
122
123 let cat = b.m().concat_(vec![upscaled, skip], 1);
124 x_up = double_conv(
125 &mut b,
126 wm,
127 cat,
128 &mut block,
129 out_c + cross_c,
130 out_c,
131 batch,
132 up_h,
133 up_w,
134 )?;
135 }
136
137 let out_w = b.load_param(wm, "out_conv.0.weight")?;
138 let out_b = b.load_param(wm, "out_conv.0.bias")?;
139 let logits = conv2d_bias(
140 &mut b.m(),
141 x_up,
142 out_w,
143 out_b,
144 batch,
145 1,
146 1,
147 1,
148 [1, 1],
149 [0, 0],
150 up_h,
151 up_w,
152 );
153 let mask = sigmoid_nchw(&mut b.m(), logits);
154 b.m().set_outputs(vec![mask]);
155
156 assert_weights_drained(wm, "detection graph")?;
157 b.finish()
158}
159
160fn double_conv(
161 b: &mut OcrGraphBuilder,
162 wm: &mut WeightMap,
163 mut x: HirNodeId,
164 block: &mut usize,
165 in_c: usize,
166 out_c: usize,
167 batch: usize,
168 h: usize,
169 w: usize,
170) -> Result<HirNodeId> {
171 let (pw0_w, pw0_b) = DET_ONNX_PW[*block];
172 x = depthwise_conv(b, wm, x, DET_DW_KEYS[*block], in_c, batch, h, w)?;
174 x = pointwise_relu(b, wm, x, pw0_w, pw0_b, in_c, out_c, batch, h, w)?;
175 *block += 1;
176 let (pw1_w, pw1_b) = DET_ONNX_PW[*block];
177 x = depthwise_conv(b, wm, x, DET_DW_KEYS[*block], out_c, batch, h, w)?;
178 x = pointwise_relu(b, wm, x, pw1_w, pw1_b, out_c, out_c, batch, h, w)?;
179 *block += 1;
180 Ok(x)
181}
182
183fn depthwise_conv(
184 b: &mut OcrGraphBuilder,
185 wm: &mut WeightMap,
186 x: HirNodeId,
187 dw_key: &str,
188 channels: usize,
189 batch: usize,
190 h: usize,
191 w: usize,
192) -> Result<HirNodeId> {
193 let weight = b.load_param(wm, dw_key)?;
194 let bias = b.zero_bias(channels)?;
195 Ok(conv2d_bias_groups(
196 &mut b.m(),
197 x,
198 weight,
199 bias,
200 batch,
201 channels,
202 3,
203 3,
204 [1, 1],
205 [1, 1],
206 channels,
207 h,
208 w,
209 ))
210}
211
212fn pointwise_relu(
213 b: &mut OcrGraphBuilder,
214 wm: &mut WeightMap,
215 x: HirNodeId,
216 w_key: &str,
217 b_key: &str,
218 _in_c: usize,
219 out_c: usize,
220 batch: usize,
221 h: usize,
222 w: usize,
223) -> Result<HirNodeId> {
224 let weight = b.load_param(wm, w_key)?;
225 let bias = b.load_param(wm, b_key)?;
226 let y = conv2d_bias(
227 &mut b.m(),
228 x,
229 weight,
230 bias,
231 batch,
232 out_c,
233 1,
234 1,
235 [1, 1],
236 [0, 0],
237 h,
238 w,
239 );
240 Ok(b.m().relu(y))
241}