use super::weights::{
DET_DW_KEYS, DET_ONNX_PW, OcrGraphBuilder, assert_weights_drained, detection_input_hw,
};
use anyhow::Result;
use rlx_core::vision_ops_ir::{
conv_transpose2d_k3s2_bias_trim, conv2d_bias, conv2d_bias_groups, max_pool2d_2x2, sigmoid_nchw,
};
use rlx_core::weight_map::WeightMap;
use rlx_ir::hir::HirNodeId;
use rlx_ir::{DType, HirGraphExt, Shape};
/// Default detection input size, stored as a two-element tuple.
///
/// NOTE(review): the `_HW` suffix suggests `(height, width)` order, i.e. 800 high
/// by 600 wide — confirm against `detection_input_hw` in `super::weights`.
// Not referenced by the code visible in this file (the `Default` impl for
// `DetectionGraphConfig` calls `detection_input_hw()` instead), hence the
// `dead_code` allowance.
#[allow(dead_code)]
pub const DEFAULT_DETECTION_INPUT_HW: (usize, usize) = (800, 600);
/// Static dimensions used when building the detection graph.
///
/// `build_detection_graph` declares its `image` input with shape
/// `[batch, 1, height, width]` from these fields.
#[derive(Clone, Copy, Debug)]
pub struct DetectionGraphConfig {
/// First dimension of the `image` input.
pub batch: usize,
/// Third dimension of the `image` input; the builder halves its tracked copy
/// (integer division) after every pooling stage.
pub height: usize,
/// Fourth dimension of the `image` input; halved alongside `height`.
pub width: usize,
}
impl Default for DetectionGraphConfig {
fn default() -> Self {
let (height, width) = detection_input_hw();
Self {
batch: 1,
height,
width,
}
}
}
/// Channel counts used by successive stages of the detection network.
///
/// In `build_detection_graph`, entry 0 is the output channel count of the input
/// double-conv, encoder stage `i` maps `DEPTH_SCALE[i]` channels to
/// `DEPTH_SCALE[i + 1]`, and decoder stage `i` produces `DEPTH_SCALE[i]` channels.
/// The number of encoder (and decoder) stages is therefore `DEPTH_SCALE.len() - 1`.
const DEPTH_SCALE: [usize; 7] = [8, 16, 32, 32, 64, 128, 256];
pub fn build_detection_graph(
wm: &mut WeightMap,
cfg: DetectionGraphConfig,
) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
let mut b = OcrGraphBuilder::new("ocr_detection");
let f = DType::F32;
let batch = cfg.batch;
let mut h = cfg.height;
let mut w = cfg.width;
let image = b.m().input("image", Shape::new(&[batch, 1, h, w], f));
let mut block = 0usize;
let mut x = double_conv(
&mut b,
wm,
image,
&mut block,
1,
DEPTH_SCALE[0],
batch,
h,
w,
)?;
let in_conv_skip = (x, h, w);
let mut x_down: Vec<(HirNodeId, usize, usize)> = Vec::new();
for level in 0..DEPTH_SCALE.len() - 1 {
let in_c = DEPTH_SCALE[level];
let out_c = DEPTH_SCALE[level + 1];
x = double_conv(&mut b, wm, x, &mut block, in_c, out_c, batch, h, w)?;
x = max_pool2d_2x2(&mut b.m(), x, batch, out_c, h, w);
h /= 2;
w /= 2;
x_down.push((x, h, w));
}
let mut x_up = x;
let mut up_h = h;
let mut up_w = w;
for up_idx in (0..DEPTH_SCALE.len() - 1).rev() {
let out_c = DEPTH_SCALE[up_idx];
let cross_c = DEPTH_SCALE[up_idx];
let (skip, skip_h, skip_w) = if up_idx == 0 {
(in_conv_skip.0, in_conv_skip.1, in_conv_skip.2)
} else {
let (skip_node, sh, sw) = x_down[up_idx - 1];
(skip_node, sh, sw)
};
let up_w_key = format!("up.{up_idx}.up.weight");
let up_b_key = format!("up.{up_idx}.up.bias");
let up_weight = b.load_param(wm, &up_w_key)?;
let up_bias = b.load_param(wm, &up_b_key)?;
let upscaled = conv_transpose2d_k3s2_bias_trim(
&mut b.m(),
x_up,
up_weight,
up_bias,
batch,
out_c,
up_h,
up_w,
skip_h,
skip_w,
);
up_h = skip_h;
up_w = skip_w;
let cat = b.m().concat_(vec![upscaled, skip], 1);
x_up = double_conv(
&mut b,
wm,
cat,
&mut block,
out_c + cross_c,
out_c,
batch,
up_h,
up_w,
)?;
}
let out_w = b.load_param(wm, "out_conv.0.weight")?;
let out_b = b.load_param(wm, "out_conv.0.bias")?;
let logits = conv2d_bias(
&mut b.m(),
x_up,
out_w,
out_b,
batch,
1,
1,
1,
[1, 1],
[0, 0],
up_h,
up_w,
);
let mask = sigmoid_nchw(&mut b.m(), logits);
b.m().set_outputs(vec![mask]);
assert_weights_drained(wm, "detection graph")?;
b.finish()
}
/// Appends two consecutive depthwise-conv + pointwise-conv/ReLU stages to `x`.
///
/// The first stage maps `in_c` channels to `out_c`; the second keeps `out_c`.
/// Each stage looks up its weight keys in `DET_DW_KEYS` / `DET_ONNX_PW` at the
/// current `*block` index and then advances `*block` by one, so the counter
/// ends up two higher on success.
///
/// Errors from the underlying weight loads are propagated; indexing past the
/// end of either key table panics.
fn double_conv(
    b: &mut OcrGraphBuilder,
    wm: &mut WeightMap,
    mut x: HirNodeId,
    block: &mut usize,
    in_c: usize,
    out_c: usize,
    batch: usize,
    h: usize,
    w: usize,
) -> Result<HirNodeId> {
    // Input channel count of each stage; both stages emit `out_c` channels.
    for stage_in_c in [in_c, out_c] {
        let idx = *block;
        let (pw_weight_key, pw_bias_key) = DET_ONNX_PW[idx];
        let dw_key = DET_DW_KEYS[idx];
        x = depthwise_conv(b, wm, x, dw_key, stage_in_c, batch, h, w)?;
        x = pointwise_relu(
            b,
            wm,
            x,
            pw_weight_key,
            pw_bias_key,
            stage_in_c,
            out_c,
            batch,
            h,
            w,
        )?;
        *block = idx + 1;
    }
    Ok(x)
}
/// Adds a 3x3 grouped convolution over `x` with `groups == channels`.
///
/// The weight is loaded from `wm` under `dw_key`; the bias comes from
/// `OcrGraphBuilder::zero_bias(channels)`. Stride and padding are both `[1, 1]`.
/// Errors from the weight load or bias creation are propagated.
fn depthwise_conv(
    b: &mut OcrGraphBuilder,
    wm: &mut WeightMap,
    x: HirNodeId,
    dw_key: &str,
    channels: usize,
    batch: usize,
    h: usize,
    w: usize,
) -> Result<HirNodeId> {
    let dw_weight = b.load_param(wm, dw_key)?;
    let dw_bias = b.zero_bias(channels)?;

    let stride = [1, 1];
    let padding = [1, 1];
    // One group per channel.
    let groups = channels;

    let out = conv2d_bias_groups(
        &mut b.m(),
        x,
        dw_weight,
        dw_bias,
        batch,
        channels,
        3,
        3,
        stride,
        padding,
        groups,
        h,
        w,
    );
    Ok(out)
}
/// Adds a 1x1 convolution with bias followed by ReLU.
///
/// Weight and bias are loaded from `wm` under `w_key` and `b_key`, in that
/// order. The convolution uses stride `[1, 1]` and padding `[0, 0]` and emits
/// `out_c` channels. `_in_c` is accepted for call-site symmetry but is not read.
/// Errors from either weight load are propagated.
fn pointwise_relu(
    b: &mut OcrGraphBuilder,
    wm: &mut WeightMap,
    x: HirNodeId,
    w_key: &str,
    b_key: &str,
    _in_c: usize,
    out_c: usize,
    batch: usize,
    h: usize,
    w: usize,
) -> Result<HirNodeId> {
    let pw_weight = b.load_param(wm, w_key)?;
    let pw_bias = b.load_param(wm, b_key)?;

    let (stride, padding) = ([1, 1], [0, 0]);
    let pre_activation = conv2d_bias(
        &mut b.m(),
        x,
        pw_weight,
        pw_bias,
        batch,
        out_c,
        1,
        1,
        stride,
        padding,
        h,
        w,
    );

    let activated = b.m().relu(pre_activation);
    Ok(activated)
}