1use super::weights::{OcrGraphBuilder, assert_weights_drained};
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{avg_pool2d, conv2d_bias, max_pool2d_2x2};
21use rlx_core::weight_map::WeightMap;
22use rlx_ir::hir::{HirMut, HirNodeId};
23use rlx_ir::{DType, HirGraphExt, Shape};
24
/// Fixed height (rows) of the single-channel `image` input every recognition
/// graph is built for.
pub const RECOGNITION_HEIGHT: usize = 64;

/// Number of classes emitted per time step by the linear output head.
pub const NUM_CLASSES: usize = 97;

/// GRU hidden size; the recurrent stages are sized for two directions of this
/// width (`2 * HIDDEN` channels in total).
const HIDDEN: usize = 256;

/// Channel count of the conv feature map that becomes each sequence step.
const FEAT: usize = 128;
29
/// Runtime dimensions of a recognition graph.
///
/// The input height is always [`RECOGNITION_HEIGHT`]; only the batch size and
/// the image width vary between builds. The type is `Eq + Hash` so callers can
/// compare configurations or use them as keys (e.g. to cache built graphs).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RecognitionGraphConfig {
    /// Number of images in the batch (leading dimension of the `image` input).
    pub batch: usize,
    /// Width in columns of each input image (last dimension of the `image` input).
    pub width: usize,
}
35
36fn build_recognition_conv_front(
37 b: &mut OcrGraphBuilder,
38 wm: &mut WeightMap,
39 image: HirNodeId,
40 batch: usize,
41 mut h: usize,
42 mut w: usize,
43) -> Result<(HirNodeId, usize)> {
44 let mut x = conv_relu(
45 b,
46 wm,
47 image,
48 "conv.0.weight",
49 "conv.0.bias",
50 batch,
51 32,
52 1,
53 h,
54 w,
55 )?;
56 x = max_pool2d_2x2(&mut b.m(), x, batch, 32, h, w);
57 h /= 2;
58 w /= 2;
59
60 x = fused_conv_relu(
61 b,
62 wm,
63 x,
64 "onnx::Conv_367",
65 "onnx::Conv_368",
66 batch,
67 64,
68 32,
69 h,
70 w,
71 )?;
72 x = max_pool2d_2x2(&mut b.m(), x, batch, 64, h, w);
73 h /= 2;
74 w /= 2;
75
76 x = conv_relu(
77 b,
78 wm,
79 x,
80 "conv.7.weight",
81 "conv.7.bias",
82 batch,
83 128,
84 64,
85 h,
86 w,
87 )?;
88 x = fused_conv_relu(
89 b,
90 wm,
91 x,
92 "onnx::Conv_370",
93 "onnx::Conv_371",
94 batch,
95 128,
96 128,
97 h,
98 w,
99 )?;
100 x = pool_2x1(&mut b.m(), x, batch, 128, h, w);
101 h /= 2;
102
103 x = conv_relu(
104 b,
105 wm,
106 x,
107 "conv.13.weight",
108 "conv.13.bias",
109 batch,
110 128,
111 128,
112 h,
113 w,
114 )?;
115 x = fused_conv_relu(
116 b,
117 wm,
118 x,
119 "onnx::Conv_373",
120 "onnx::Conv_374",
121 batch,
122 128,
123 128,
124 h,
125 w,
126 )?;
127 x = pool_2x1(&mut b.m(), x, batch, 128, h, w);
128 h /= 2;
129
130 x = fused_conv2x2(
131 b,
132 wm,
133 x,
134 "onnx::Conv_376",
135 "onnx::Conv_377",
136 batch,
137 128,
138 128,
139 h,
140 w,
141 )?;
142 h += 1;
143 w += 1;
144 x = avg_pool2d(&mut b.m(), x, [4, 1], [4, 1], batch, 128, h, w);
145 let seq = w;
146 let x = b
147 .m()
148 .reshape_(x, vec![batch as i64, FEAT as i64, seq as i64]);
149 let x = b.m().transpose_(x, vec![2, 0, 1]);
150 Ok((x, seq))
151}
152
153pub fn build_recognition_conv_graph(
155 wm: &mut WeightMap,
156 cfg: RecognitionGraphConfig,
157) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
158 let mut b = OcrGraphBuilder::new("ocr_recognition_conv");
159 let batch = cfg.batch;
160 let h = RECOGNITION_HEIGHT;
161 let w = cfg.width;
162 let image = b
163 .m()
164 .input("image", Shape::new(&[batch, 1, h, w], DType::F32));
165 let (x, _seq) = build_recognition_conv_front(&mut b, wm, image, batch, h, w)?;
166 b.m().set_outputs(vec![x]);
167 b.finish()
168}
169
170pub fn build_recognition_after_g1_graph(
172 wm: &mut WeightMap,
173 cfg: RecognitionGraphConfig,
174) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
175 build_recognition_graph_inner(wm, cfg, Some(1))
176}
177
178pub fn build_recognition_after_g2_graph(
180 wm: &mut WeightMap,
181 cfg: RecognitionGraphConfig,
182) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
183 build_recognition_graph_inner(wm, cfg, Some(2))
184}
185
186pub fn build_recognition_after_logits_graph(
188 wm: &mut WeightMap,
189 cfg: RecognitionGraphConfig,
190) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
191 build_recognition_graph_inner(wm, cfg, Some(3))
192}
193
194pub fn build_recognition_graph(
195 wm: &mut WeightMap,
196 cfg: RecognitionGraphConfig,
197) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
198 build_recognition_graph_inner(wm, cfg, None)
199}
200
201fn build_recognition_graph_inner(
202 wm: &mut WeightMap,
203 cfg: RecognitionGraphConfig,
204 stop_after_gru: Option<u8>,
205) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
206 let mut b = OcrGraphBuilder::new("ocr_recognition");
207 let batch = cfg.batch;
208 let h = RECOGNITION_HEIGHT;
209 let w = cfg.width;
210
211 let image = b
212 .m()
213 .input("image", Shape::new(&[batch, 1, h, w], DType::F32));
214
215 let (x, seq) = build_recognition_conv_front(&mut b, wm, image, batch, h, w)?;
216
217 let _seq_lens = gru_seq_lens_param(&mut b, batch, seq)?;
221 let _init_h = gru_init_hidden_param(&mut b, batch, HIDDEN, 2)?;
222 let _w1 = b.load_param(wm, "onnx::GRU_422")?;
223 let _r1 = b.load_param(wm, "onnx::GRU_423")?;
224 let _b1 = b.load_param(wm, "onnx::GRU_421")?;
225
226 let pad = (2 * HIDDEN).saturating_sub(FEAT);
227 let g1 = if pad == 0 {
228 x
229 } else {
230 let key = format!("ocr.recognition.pad_{seq}_{batch}_{pad}");
231 let zeros = vec![0.0f32; seq * batch * pad];
232 let z = b
233 .m()
234 .param(&key, Shape::new(&[seq, batch, pad], DType::F32));
235 b.params.insert(key, zeros);
236 b.m().concat_(vec![x, z], 2)
237 };
238 if stop_after_gru == Some(1) {
239 b.m().set_outputs(vec![g1]);
240 return b.finish();
241 }
242
243 let _w2 = b.load_param(wm, "onnx::GRU_465")?;
244 let _r2 = b.load_param(wm, "onnx::GRU_466")?;
245 let _b2 = b.load_param(wm, "onnx::GRU_464")?;
246 let _init_h2 = gru_init_hidden_param(&mut b, batch, HIDDEN, 2)?;
247 let g2 = g1;
248 if stop_after_gru == Some(2) {
249 b.m().set_outputs(vec![g2]);
250 return b.finish();
251 }
252
253 let head_w = b.load_param(wm, "onnx::MatMul_467")?;
254 let head_b = b.load_param(wm, "output.0.bias")?;
255 let logits = b.m().mm(g2, head_w);
256 let logits = add_bias_seq(&mut b, logits, head_b, batch, seq, NUM_CLASSES)?;
257 if stop_after_gru == Some(3) {
258 b.m().set_outputs(vec![logits]);
259 return b.finish();
260 }
261 let out = b.m().transpose_(logits, vec![1, 0, 2]);
262 b.m().set_outputs(vec![out]);
263
264 assert_weights_drained(wm, "recognition graph")?;
265 b.finish()
266}
267
268fn conv_relu(
269 b: &mut OcrGraphBuilder,
270 wm: &mut WeightMap,
271 x: HirNodeId,
272 w_key: &str,
273 bias_key: &str,
274 batch: usize,
275 out_c: usize,
276 _in_c: usize,
277 h: usize,
278 w: usize,
279) -> Result<HirNodeId> {
280 let weight = b.load_param(wm, w_key)?;
281 let bias = b.load_param(wm, bias_key)?;
282 let y = conv2d_bias(
283 &mut b.m(),
284 x,
285 weight,
286 bias,
287 batch,
288 out_c,
289 3,
290 3,
291 [1, 1],
292 [1, 1],
293 h,
294 w,
295 );
296 Ok(b.m().relu(y))
297}
298
299fn fused_conv2x2(
301 b: &mut OcrGraphBuilder,
302 wm: &mut WeightMap,
303 x: HirNodeId,
304 w_key: &str,
305 bias_key: &str,
306 batch: usize,
307 out_c: usize,
308 _in_c: usize,
309 h: usize,
310 w: usize,
311) -> Result<HirNodeId> {
312 let weight = b.load_param(wm, w_key)?;
313 let bias = b.load_param(wm, bias_key)?;
314 let out_h = h + 1;
315 let out_w = w + 1;
316 Ok(conv2d_bias(
317 &mut b.m(),
318 x,
319 weight,
320 bias,
321 batch,
322 out_c,
323 2,
324 2,
325 [1, 1],
326 [1, 1],
327 out_h,
328 out_w,
329 ))
330}
331
332fn fused_conv_relu(
333 b: &mut OcrGraphBuilder,
334 wm: &mut WeightMap,
335 x: HirNodeId,
336 w_key: &str,
337 bias_key: &str,
338 batch: usize,
339 out_c: usize,
340 _in_c: usize,
341 h: usize,
342 w: usize,
343) -> Result<HirNodeId> {
344 let weight = b.load_param(wm, w_key)?;
345 let bias = b.load_param(wm, bias_key)?;
346 let y = conv2d_bias(
347 &mut b.m(),
348 x,
349 weight,
350 bias,
351 batch,
352 out_c,
353 3,
354 3,
355 [1, 1],
356 [1, 1],
357 h,
358 w,
359 );
360 Ok(b.m().relu(y))
361}
362
363fn pool_2x1(
364 g: &mut HirMut<'_>,
365 x: HirNodeId,
366 batch: usize,
367 c: usize,
368 h: usize,
369 w: usize,
370) -> HirNodeId {
371 use rlx_ir::op::{Op, ReduceOp};
372 let dt = g.shape(x).dtype();
373 let out_h = (h.saturating_sub(2)) / 2 + 1;
374 let out_w = w;
375 let out_shape = rlx_core::vision_ops_ir::nchw_shape(batch, c, out_h, out_w, dt);
376 g.add_node(
377 Op::Pool {
378 kind: ReduceOp::Max,
379 kernel_size: vec![2, 1],
380 stride: vec![2, 1],
381 padding: vec![0, 0],
382 },
383 vec![x],
384 out_shape,
385 )
386}
387
388fn gru_seq_lens_param(b: &mut OcrGraphBuilder, batch: usize, seq: usize) -> Result<HirNodeId> {
389 let key = format!("ocr.gru.seq_lens.{batch}x{seq}");
390 let data = vec![seq as f32; batch];
391 let id = b.m().param(&key, Shape::new(&[batch], DType::F32));
392 b.params.insert(key, data);
393 Ok(id)
394}
395
396fn gru_init_hidden_param(
397 b: &mut OcrGraphBuilder,
398 batch: usize,
399 hidden: usize,
400 num_directions: usize,
401) -> Result<HirNodeId> {
402 let key = format!("ocr.gru.init_h.{num_directions}x{batch}x{hidden}");
403 let n = num_directions * batch * hidden;
404 let id = b.m().param(
405 &key,
406 Shape::new(&[num_directions, batch, hidden], DType::F32),
407 );
408 b.params.insert(key, vec![0f32; n]);
409 Ok(id)
410}
411
/// Applies log-softmax in place to each consecutive run of `classes` values.
///
/// Each row is shifted by its maximum before exponentiation, so large logits do
/// not overflow `exp`. The result is `(x - max) - ln(sum(exp(x - max)))`. Empty
/// `data` is left untouched.
///
/// # Panics
///
/// Panics if `classes` is zero or `data.len()` is not a multiple of `classes`.
pub fn log_softmax_last_axis(data: &mut [f32], classes: usize) {
    assert!(classes > 0 && data.len().is_multiple_of(classes));
    data.chunks_mut(classes).for_each(|row| {
        let peak = row.iter().fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
        let log_norm = row.iter().map(|&v| (v - peak).exp()).sum::<f32>().ln();
        row.iter_mut().for_each(|v| *v = (*v - peak) - log_norm);
    });
}
423
424fn add_bias_seq(
425 b: &mut OcrGraphBuilder,
426 y: HirNodeId,
427 bias: HirNodeId,
428 _batch: usize,
429 _seq: usize,
430 classes: usize,
431) -> Result<HirNodeId> {
432 let bias3 = b.m().reshape_(bias, vec![1, 1, classes as i64]);
433 Ok(b.m().add(y, bias3))
434}