Skip to main content

rlx_ocr/model/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::{Context, Result, bail};
17use rlx_core::weight_map::WeightMap;
18use rlx_ir::hir::{HirMut, HirNodeId};
19use rlx_ir::{DType, Shape};
20use std::collections::HashMap;
21
22pub struct OcrGraphBuilder {
23    pub hir: rlx_ir::hir::HirModule,
24    pub params: HashMap<String, Vec<f32>>,
25    zero_bias: HashMap<usize, HirNodeId>,
26}
27
28impl OcrGraphBuilder {
29    pub fn new(name: &str) -> Self {
30        Self {
31            hir: rlx_ir::hir::HirModule::new(name),
32            params: HashMap::new(),
33            zero_bias: HashMap::new(),
34        }
35    }
36
37    pub fn m(&mut self) -> HirMut<'_> {
38        HirMut::new(&mut self.hir)
39    }
40
41    pub fn zero_bias(&mut self, channels: usize) -> Result<HirNodeId> {
42        if let Some(&id) = self.zero_bias.get(&channels) {
43            return Ok(id);
44        }
45        let key = format!("ocr.zero_bias.{channels}");
46        let data = vec![0f32; channels];
47        let id = self.m().param(&key, Shape::new(&[channels], DType::F32));
48        self.params.insert(key, data);
49        self.zero_bias.insert(channels, id);
50        Ok(id)
51    }
52
53    pub fn load_param(&mut self, wm: &mut WeightMap, key: &str) -> Result<HirNodeId> {
54        let (data, shape) = wm
55            .take(key)
56            .with_context(|| format!("missing weight {key}"))?;
57        let id = self.m().param(key, Shape::new(&shape, DType::F32));
58        self.params.insert(key.to_string(), data);
59        Ok(id)
60    }
61
62    pub fn load_param_optional(
63        &mut self,
64        wm: &mut WeightMap,
65        key: &str,
66    ) -> Result<Option<HirNodeId>> {
67        if !wm.has(key) {
68            return Ok(None);
69        }
70        Ok(Some(self.load_param(wm, key)?))
71    }
72
73    pub fn finish(self) -> Result<(rlx_ir::Graph, HashMap<String, Vec<f32>>)> {
74        rlx_core::flow_util::graph_from_hir(self.hir, self.params)
75    }
76}
77
/// ONNX initializer names for the 26 fused pointwise-conv + BatchNorm blocks of the
/// detection network, in traversal order (13 `DoubleConv` × 2).
///
/// Entry `i` is `("onnx::Conv_{470 + 3i}", "onnx::Conv_{471 + 3i}")`, presumably the
/// fused weight and bias of that block. Each row below holds the two blocks of one
/// `DoubleConv`. NOTE(review): presumably index-aligned with `DET_DW_KEYS` — confirm.
#[rustfmt::skip]
pub const DET_ONNX_PW: [(&str, &str); 26] = [
    ("onnx::Conv_470", "onnx::Conv_471"), ("onnx::Conv_473", "onnx::Conv_474"),
    ("onnx::Conv_476", "onnx::Conv_477"), ("onnx::Conv_479", "onnx::Conv_480"),
    ("onnx::Conv_482", "onnx::Conv_483"), ("onnx::Conv_485", "onnx::Conv_486"),
    ("onnx::Conv_488", "onnx::Conv_489"), ("onnx::Conv_491", "onnx::Conv_492"),
    ("onnx::Conv_494", "onnx::Conv_495"), ("onnx::Conv_497", "onnx::Conv_498"),
    ("onnx::Conv_500", "onnx::Conv_501"), ("onnx::Conv_503", "onnx::Conv_504"),
    ("onnx::Conv_506", "onnx::Conv_507"), ("onnx::Conv_509", "onnx::Conv_510"),
    ("onnx::Conv_512", "onnx::Conv_513"), ("onnx::Conv_515", "onnx::Conv_516"),
    ("onnx::Conv_518", "onnx::Conv_519"), ("onnx::Conv_521", "onnx::Conv_522"),
    ("onnx::Conv_524", "onnx::Conv_525"), ("onnx::Conv_527", "onnx::Conv_528"),
    ("onnx::Conv_530", "onnx::Conv_531"), ("onnx::Conv_533", "onnx::Conv_534"),
    ("onnx::Conv_536", "onnx::Conv_537"), ("onnx::Conv_539", "onnx::Conv_540"),
    ("onnx::Conv_542", "onnx::Conv_543"), ("onnx::Conv_545", "onnx::Conv_546"),
];
107
/// State-dict weight keys for the 26 detection conv blocks (presumably depthwise,
/// per the `DW` name), in traversal order: `in_conv`, then `down.0` … `down.5`, then
/// `up.5` … `up.0`. Each row below holds the two convs of one `DoubleConv`.
/// NOTE(review): presumably index-aligned with `DET_ONNX_PW` — confirm.
#[rustfmt::skip]
pub const DET_DW_KEYS: [&str; 26] = [
    "in_conv.seq.0.seq.0.weight",       "in_conv.seq.1.seq.0.weight",
    "down.0.seq.0.seq.0.seq.0.weight",  "down.0.seq.0.seq.1.seq.0.weight",
    "down.1.seq.0.seq.0.seq.0.weight",  "down.1.seq.0.seq.1.seq.0.weight",
    "down.2.seq.0.seq.0.seq.0.weight",  "down.2.seq.0.seq.1.seq.0.weight",
    "down.3.seq.0.seq.0.seq.0.weight",  "down.3.seq.0.seq.1.seq.0.weight",
    "down.4.seq.0.seq.0.seq.0.weight",  "down.4.seq.0.seq.1.seq.0.weight",
    "down.5.seq.0.seq.0.seq.0.weight",  "down.5.seq.0.seq.1.seq.0.weight",
    "up.5.contract.seq.0.seq.0.weight", "up.5.contract.seq.1.seq.0.weight",
    "up.4.contract.seq.0.seq.0.weight", "up.4.contract.seq.1.seq.0.weight",
    "up.3.contract.seq.0.seq.0.weight", "up.3.contract.seq.1.seq.0.weight",
    "up.2.contract.seq.0.seq.0.weight", "up.2.contract.seq.1.seq.0.weight",
    "up.1.contract.seq.0.seq.0.weight", "up.1.contract.seq.1.seq.0.weight",
    "up.0.contract.seq.0.seq.0.weight", "up.0.contract.seq.1.seq.0.weight",
];
136
137pub fn detection_input_hw() -> (usize, usize) {
138    if let Ok(s) = std::env::var("OCR_DETECTION_HW") {
139        if let Some(hw) = parse_hw(&s) {
140            return hw;
141        }
142    }
143    (800, 600)
144}
145
/// Parses a `"height,width"` pair such as `"800,600"` or `" 800 , 600 "`.
///
/// Whitespace around each number is ignored. Returns `None` when the comma is
/// missing, either side is not an unsigned integer (so `"1,2,3"` is rejected),
/// or either dimension is zero, because a zero-area input size is never usable.
pub fn parse_hw(s: &str) -> Option<(usize, usize)> {
    let (h, w) = s.split_once(',')?;
    let h: usize = h.trim().parse().ok()?;
    let w: usize = w.trim().parse().ok()?;
    // Reject degenerate sizes so callers (e.g. an env override) fall back
    // instead of building a graph for an empty image.
    (h > 0 && w > 0).then_some((h, w))
}
150
151pub fn assert_weights_drained(wm: &WeightMap, context: &str) -> Result<()> {
152    let leftover: Vec<_> = wm
153        .keys()
154        .filter(|k| !k.starts_with('/') && !k.contains("Constant") && !k.contains("Unsqueeze"))
155        .collect();
156    if leftover.is_empty() {
157        return Ok(());
158    }
159    let mut keys = leftover;
160    keys.sort();
161    bail!("{context}: unmapped weights: {keys:?}");
162}