Skip to main content

rlx_ocr/rlx/
detection.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::capabilities::validate_device;
17use crate::config::DetectionParams;
18use crate::detection::postprocess::word_rects_from_mask;
19use crate::model::{DetectionGraphConfig, build_detection_graph};
20use crate::preprocess::BLACK_VALUE;
21use crate::weights::{
22    HF_DETECTION_ST, HF_DETECTION_ST_FULL, SafetensorsFile, prefer_safetensors_path,
23};
24use anyhow::{Result, anyhow};
25use rlx_core::flow_bridge::compile_options_for_profile;
26use rlx_core::flow_util::attach_built_params;
27use rlx_flow::CompileProfile;
28use rlx_runtime::{CompiledGraph, Device, Session};
29#[cfg(feature = "tensor-ops")]
30use rten::{FloatOperators, Operators};
31use rten_imageproc::RotatedRect;
32use rten_tensor::prelude::*;
33use rten_tensor::{NdTensor, NdTensorView};
34use std::path::Path;
35use std::sync::Mutex;
36
/// Text detector using a compiled native RLX U-Net graph.
///
/// The detector owns one compiled graph with a fixed input size. Inference
/// goes through a [`Mutex`], so a shared `&RlxTextDetector` can be used to run
/// detection even though running the graph needs mutable access to it.
pub struct RlxTextDetector {
    /// Compiled detection graph; locked while an inference run is executed.
    compiled: Mutex<CompiledGraph>,
    /// Post-processing parameters (`text_threshold` and `min_area` are read
    /// when converting the predicted mask into word rectangles).
    params: DetectionParams,
    /// Fixed `(height, width)` of the graph input, taken from the graph config
    /// the detector was built with.
    input_hw: (usize, usize),
    // The device is recorded at construction time but not read by any of the
    // methods visible in this file, hence the lint allowance.
    #[allow(dead_code)]
    device: Device,
}
45
46impl RlxTextDetector {
47    pub fn from_path(
48        path: impl AsRef<Path>,
49        params: DetectionParams,
50        device: Device,
51    ) -> Result<Self> {
52        Self::from_safetensors(path.as_ref(), params, device)
53    }
54
55    pub fn from_safetensors(path: &Path, params: DetectionParams, device: Device) -> Result<Self> {
56        Self::from_safetensors_sized(path, params, DetectionGraphConfig::default(), device)
57    }
58
59    pub fn from_safetensors_sized(
60        path: &Path,
61        params: DetectionParams,
62        cfg: DetectionGraphConfig,
63        device: Device,
64    ) -> Result<Self> {
65        validate_device(device)?;
66        let mut wm = SafetensorsFile::open(path)?.weight_map()?;
67        let input_hw = (cfg.height, cfg.width);
68        let (graph, param_map) = build_detection_graph(&mut wm, cfg)?;
69        let opts = compile_options_for_profile(&CompileProfile::encoder(), device);
70        let mut compiled = Session::new(device).compile_with(graph, &opts);
71        attach_built_params(&mut compiled, param_map, &[]);
72        Ok(Self {
73            compiled: Mutex::new(compiled),
74            params,
75            input_hw,
76            device,
77        })
78    }
79
80    pub fn from_model_dir(dir: &Path, params: DetectionParams, device: Device) -> Result<Self> {
81        let path =
82            prefer_safetensors_path(dir, crate::weights::HF_DETECTION_ST, HF_DETECTION_ST_FULL);
83        if !path.is_file() {
84            let _fallback = dir.join(HF_DETECTION_ST);
85            anyhow::bail!(
86                "missing detection safetensors in {dir:?} (need ocr-detection-full.safetensors); \
87                 run `rlx-ocr-convert` on {:?}",
88                dir.join("text-detection-ssfbcj81.rten")
89            );
90        }
91        Self::from_safetensors(&path, params, device)
92    }
93
94    pub fn fixed_input_hw(&self) -> Option<(usize, usize)> {
95        Some(self.input_hw)
96    }
97
98    pub fn detect_words(&self, image: NdTensorView<f32, 3>) -> Result<Vec<RotatedRect>> {
99        let mask = self.detect_text_pixels(image)?;
100        Ok(word_rects_from_mask(
101            mask.view(),
102            self.params.text_threshold,
103            self.params.min_area,
104        ))
105    }
106
107    pub fn detect_text_pixels(&self, image: NdTensorView<f32, 3>) -> Result<NdTensor<f32, 2>> {
108        let [img_chans, img_height, img_width] = image.shape();
109        let image = image.reshaped([1, img_chans, img_height, img_width]);
110        let (in_height, in_width) = self.input_hw;
111
112        let pad_bottom = (in_height as i32 - img_height as i32).max(0);
113        let pad_right = (in_width as i32 - img_width as i32).max(0);
114        let image = (pad_bottom > 0 || pad_right > 0)
115            .then(|| {
116                let pads = &[0, 0, 0, 0, 0, 0, pad_bottom, pad_right];
117                image.pad(pads.into(), BLACK_VALUE)
118            })
119            .transpose()?
120            .map(|t| t.into_cow())
121            .unwrap_or(image.as_dyn().as_cow());
122
123        let image = (image.size(2) != in_height || image.size(3) != in_width)
124            .then(|| image.resize_image([in_height, in_width]))
125            .transpose()?
126            .map(|t| t.into_cow())
127            .unwrap_or(image);
128
129        let mut compiled = self.compiled.lock().map_err(|_| anyhow!("lock poisoned"))?;
130        let input: Vec<f32> = image.iter().copied().collect();
131        let outputs = compiled.run(&[("image", input.as_slice())]);
132        let flat = outputs
133            .into_iter()
134            .next()
135            .ok_or_else(|| anyhow!("detection returned no output"))?;
136
137        let valid_h = in_height - pad_bottom as usize;
138        let valid_w = in_width - pad_right as usize;
139        let mask = NdTensor::from_data([1, 1, in_height, in_width], flat);
140        // Keep NCHW rank: `slice((0, 0, ..))` squeezes batch/channel and breaks `resize_image`.
141        let mask = mask
142            .slice((.., .., ..valid_h, ..valid_w))
143            .resize_image([img_height, img_width])?;
144        Ok(mask.into_shape([img_height, img_width]))
145    }
146}