1use anyhow::{Context, Result, bail};
17use rlx_core::weight_map::WeightMap;
18use rlx_ir::hir::{HirMut, HirNodeId};
19use rlx_ir::{DType, Shape};
20use std::collections::HashMap;
21
22pub struct OcrGraphBuilder {
23 pub hir: rlx_ir::hir::HirModule,
24 pub params: HashMap<String, Vec<f32>>,
25 zero_bias: HashMap<usize, HirNodeId>,
26}
27
28impl OcrGraphBuilder {
29 pub fn new(name: &str) -> Self {
30 Self {
31 hir: rlx_ir::hir::HirModule::new(name),
32 params: HashMap::new(),
33 zero_bias: HashMap::new(),
34 }
35 }
36
37 pub fn m(&mut self) -> HirMut<'_> {
38 HirMut::new(&mut self.hir)
39 }
40
41 pub fn zero_bias(&mut self, channels: usize) -> Result<HirNodeId> {
42 if let Some(&id) = self.zero_bias.get(&channels) {
43 return Ok(id);
44 }
45 let key = format!("ocr.zero_bias.{channels}");
46 let data = vec![0f32; channels];
47 let id = self.m().param(&key, Shape::new(&[channels], DType::F32));
48 self.params.insert(key, data);
49 self.zero_bias.insert(channels, id);
50 Ok(id)
51 }
52
53 pub fn load_param(&mut self, wm: &mut WeightMap, key: &str) -> Result<HirNodeId> {
54 let (data, shape) = wm
55 .take(key)
56 .with_context(|| format!("missing weight {key}"))?;
57 let id = self.m().param(key, Shape::new(&shape, DType::F32));
58 self.params.insert(key.to_string(), data);
59 Ok(id)
60 }
61
62 pub fn load_param_optional(
63 &mut self,
64 wm: &mut WeightMap,
65 key: &str,
66 ) -> Result<Option<HirNodeId>> {
67 if !wm.has(key) {
68 return Ok(None);
69 }
70 Ok(Some(self.load_param(wm, key)?))
71 }
72
73 pub fn finish(self) -> Result<(rlx_ir::Graph, HashMap<String, Vec<f32>>)> {
74 rlx_core::flow_util::graph_from_hir(self.hir, self.params)
75 }
76}
77
/// ONNX initializer-name pairs for 26 convolutions of the detection model.
///
/// Each entry is a pair of auto-generated `onnx::Conv_<N>` initializer names
/// with consecutive ids (`N`, `N + 1`). Successive entries advance `N` by 3,
/// from 470 to 545.
///
/// NOTE(review): presumably each pair is `(weight, bias)` of a pointwise
/// ("PW") convolution, and entry `i` lines up with `DET_DW_KEYS[i]` (both
/// tables have 26 entries). Confirm against the code that consumes this table.
pub const DET_ONNX_PW: [(&str, &str); 26] = [
    ("onnx::Conv_470", "onnx::Conv_471"),
    ("onnx::Conv_473", "onnx::Conv_474"),
    ("onnx::Conv_476", "onnx::Conv_477"),
    ("onnx::Conv_479", "onnx::Conv_480"),
    ("onnx::Conv_482", "onnx::Conv_483"),
    ("onnx::Conv_485", "onnx::Conv_486"),
    ("onnx::Conv_488", "onnx::Conv_489"),
    ("onnx::Conv_491", "onnx::Conv_492"),
    ("onnx::Conv_494", "onnx::Conv_495"),
    ("onnx::Conv_497", "onnx::Conv_498"),
    ("onnx::Conv_500", "onnx::Conv_501"),
    ("onnx::Conv_503", "onnx::Conv_504"),
    ("onnx::Conv_506", "onnx::Conv_507"),
    ("onnx::Conv_509", "onnx::Conv_510"),
    ("onnx::Conv_512", "onnx::Conv_513"),
    ("onnx::Conv_515", "onnx::Conv_516"),
    ("onnx::Conv_518", "onnx::Conv_519"),
    ("onnx::Conv_521", "onnx::Conv_522"),
    ("onnx::Conv_524", "onnx::Conv_525"),
    ("onnx::Conv_527", "onnx::Conv_528"),
    ("onnx::Conv_530", "onnx::Conv_531"),
    ("onnx::Conv_533", "onnx::Conv_534"),
    ("onnx::Conv_536", "onnx::Conv_537"),
    ("onnx::Conv_539", "onnx::Conv_540"),
    ("onnx::Conv_542", "onnx::Conv_543"),
    ("onnx::Conv_545", "onnx::Conv_546"),
];
107
/// Dotted parameter keys for 26 convolution weights of the detection model.
///
/// The order is:
/// - two `in_conv` entries,
/// - two entries for each `down` stage from `down.0` to `down.5`,
/// - two `contract` entries for each `up` stage, listed from `up.5` down to
///   `up.0`.
///
/// NOTE(review): "DW" presumably means depthwise, and entry `i` presumably
/// pairs with `DET_ONNX_PW[i]` (both tables have 26 entries). Confirm against
/// the code that consumes this table.
pub const DET_DW_KEYS: [&str; 26] = [
    "in_conv.seq.0.seq.0.weight",
    "in_conv.seq.1.seq.0.weight",
    "down.0.seq.0.seq.0.seq.0.weight",
    "down.0.seq.0.seq.1.seq.0.weight",
    "down.1.seq.0.seq.0.seq.0.weight",
    "down.1.seq.0.seq.1.seq.0.weight",
    "down.2.seq.0.seq.0.seq.0.weight",
    "down.2.seq.0.seq.1.seq.0.weight",
    "down.3.seq.0.seq.0.seq.0.weight",
    "down.3.seq.0.seq.1.seq.0.weight",
    "down.4.seq.0.seq.0.seq.0.weight",
    "down.4.seq.0.seq.1.seq.0.weight",
    "down.5.seq.0.seq.0.seq.0.weight",
    "down.5.seq.0.seq.1.seq.0.weight",
    "up.5.contract.seq.0.seq.0.weight",
    "up.5.contract.seq.1.seq.0.weight",
    "up.4.contract.seq.0.seq.0.weight",
    "up.4.contract.seq.1.seq.0.weight",
    "up.3.contract.seq.0.seq.0.weight",
    "up.3.contract.seq.1.seq.0.weight",
    "up.2.contract.seq.0.seq.0.weight",
    "up.2.contract.seq.1.seq.0.weight",
    "up.1.contract.seq.0.seq.0.weight",
    "up.1.contract.seq.1.seq.0.weight",
    "up.0.contract.seq.0.seq.0.weight",
    "up.0.contract.seq.1.seq.0.weight",
];
136
137pub fn detection_input_hw() -> (usize, usize) {
138 if let Ok(s) = std::env::var("OCR_DETECTION_HW") {
139 if let Some(hw) = parse_hw(&s) {
140 return hw;
141 }
142 }
143 (800, 600)
144}
145
/// Parses an `"H,W"` string into `(height, width)`.
///
/// Whitespace around each component is trimmed. The string is split at the
/// first comma only, so everything after it must parse as the width; for
/// example, `"1,2,3"` is rejected.
///
/// Returns `None` when there is no comma or either side is not a valid
/// `usize`.
pub fn parse_hw(s: &str) -> Option<(usize, usize)> {
    let mut halves = s.splitn(2, ',');
    let height = halves.next()?.trim().parse::<usize>().ok()?;
    let width = halves.next()?.trim().parse::<usize>().ok()?;
    Some((height, width))
}
150
151pub fn assert_weights_drained(wm: &WeightMap, context: &str) -> Result<()> {
152 let leftover: Vec<_> = wm
153 .keys()
154 .filter(|k| !k.starts_with('/') && !k.contains("Constant") && !k.contains("Unsqueeze"))
155 .collect();
156 if leftover.is_empty() {
157 return Ok(());
158 }
159 let mut keys = leftover;
160 keys.sort();
161 bail!("{context}: unmapped weights: {keys:?}");
162}