1use crate::capabilities::validate_device;
17use crate::config::DetectionParams;
18use crate::detection::postprocess::word_rects_from_mask;
19use crate::model::{DetectionGraphConfig, build_detection_graph};
20use crate::preprocess::BLACK_VALUE;
21use crate::weights::{
22 HF_DETECTION_ST, HF_DETECTION_ST_FULL, SafetensorsFile, prefer_safetensors_path,
23};
24use anyhow::{Result, anyhow};
25use rlx_core::flow_bridge::compile_options_for_profile;
26use rlx_core::flow_util::attach_built_params;
27use rlx_flow::CompileProfile;
28use rlx_runtime::{CompiledGraph, Device, Session};
29#[cfg(feature = "tensor-ops")]
30use rten::{FloatOperators, Operators};
31use rten_imageproc::RotatedRect;
32use rten_tensor::prelude::*;
33use rten_tensor::{NdTensor, NdTensorView};
34use std::path::Path;
35use std::sync::Mutex;
36
37pub struct RlxTextDetector {
39 compiled: Mutex<CompiledGraph>,
40 params: DetectionParams,
41 input_hw: (usize, usize),
42 #[allow(dead_code)]
43 device: Device,
44}
45
46impl RlxTextDetector {
47 pub fn from_path(
48 path: impl AsRef<Path>,
49 params: DetectionParams,
50 device: Device,
51 ) -> Result<Self> {
52 Self::from_safetensors(path.as_ref(), params, device)
53 }
54
55 pub fn from_safetensors(path: &Path, params: DetectionParams, device: Device) -> Result<Self> {
56 Self::from_safetensors_sized(path, params, DetectionGraphConfig::default(), device)
57 }
58
59 pub fn from_safetensors_sized(
60 path: &Path,
61 params: DetectionParams,
62 cfg: DetectionGraphConfig,
63 device: Device,
64 ) -> Result<Self> {
65 validate_device(device)?;
66 let mut wm = SafetensorsFile::open(path)?.weight_map()?;
67 let input_hw = (cfg.height, cfg.width);
68 let (graph, param_map) = build_detection_graph(&mut wm, cfg)?;
69 let opts = compile_options_for_profile(&CompileProfile::encoder(), device);
70 let mut compiled = Session::new(device).compile_with(graph, &opts);
71 attach_built_params(&mut compiled, param_map, &[]);
72 Ok(Self {
73 compiled: Mutex::new(compiled),
74 params,
75 input_hw,
76 device,
77 })
78 }
79
80 pub fn from_model_dir(dir: &Path, params: DetectionParams, device: Device) -> Result<Self> {
81 let path =
82 prefer_safetensors_path(dir, crate::weights::HF_DETECTION_ST, HF_DETECTION_ST_FULL);
83 if !path.is_file() {
84 let _fallback = dir.join(HF_DETECTION_ST);
85 anyhow::bail!(
86 "missing detection safetensors in {dir:?} (need ocr-detection-full.safetensors); \
87 run `rlx-ocr-convert` on {:?}",
88 dir.join("text-detection-ssfbcj81.rten")
89 );
90 }
91 Self::from_safetensors(&path, params, device)
92 }
93
94 pub fn fixed_input_hw(&self) -> Option<(usize, usize)> {
95 Some(self.input_hw)
96 }
97
98 pub fn detect_words(&self, image: NdTensorView<f32, 3>) -> Result<Vec<RotatedRect>> {
99 let mask = self.detect_text_pixels(image)?;
100 Ok(word_rects_from_mask(
101 mask.view(),
102 self.params.text_threshold,
103 self.params.min_area,
104 ))
105 }
106
107 pub fn detect_text_pixels(&self, image: NdTensorView<f32, 3>) -> Result<NdTensor<f32, 2>> {
108 let [img_chans, img_height, img_width] = image.shape();
109 let image = image.reshaped([1, img_chans, img_height, img_width]);
110 let (in_height, in_width) = self.input_hw;
111
112 let pad_bottom = (in_height as i32 - img_height as i32).max(0);
113 let pad_right = (in_width as i32 - img_width as i32).max(0);
114 let image = (pad_bottom > 0 || pad_right > 0)
115 .then(|| {
116 let pads = &[0, 0, 0, 0, 0, 0, pad_bottom, pad_right];
117 image.pad(pads.into(), BLACK_VALUE)
118 })
119 .transpose()?
120 .map(|t| t.into_cow())
121 .unwrap_or(image.as_dyn().as_cow());
122
123 let image = (image.size(2) != in_height || image.size(3) != in_width)
124 .then(|| image.resize_image([in_height, in_width]))
125 .transpose()?
126 .map(|t| t.into_cow())
127 .unwrap_or(image);
128
129 let mut compiled = self.compiled.lock().map_err(|_| anyhow!("lock poisoned"))?;
130 let input: Vec<f32> = image.iter().copied().collect();
131 let outputs = compiled.run(&[("image", input.as_slice())]);
132 let flat = outputs
133 .into_iter()
134 .next()
135 .ok_or_else(|| anyhow!("detection returned no output"))?;
136
137 let valid_h = in_height - pad_bottom as usize;
138 let valid_w = in_width - pad_right as usize;
139 let mask = NdTensor::from_data([1, 1, in_height, in_width], flat);
140 let mask = mask
142 .slice((.., .., ..valid_h, ..valid_w))
143 .resize_image([img_height, img_width])?;
144 Ok(mask.into_shape([img_height, img_width]))
145 }
146}