1use crate::ctc::{
2 CtcGreedyDecoder, CtcGreedyDecoderConfig, CtcGreedyDecoderError, DecodedSequence,
3};
4use crate::dictionary::RecDictionary;
5use crate::preprocessing::PreprocessedRecBatch;
6use ndarray::Array3;
7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::path::Path;
10use std::sync::Arc;
11use tract_onnx::prelude::*;
12use tract_onnx::tract_core::anyhow::anyhow;
13
/// Output of one recognition inference pass over a preprocessed batch.
#[derive(Debug, Clone)]
pub struct RecInferenceOutput {
    /// Raw CTC logits shaped `(batch, time_steps, classes)`.
    pub logits: Array3<f32>,
    /// For each sample, the number of leading timesteps that correspond to
    /// real (non-padding) image width; entries are in `[1, time_steps]`
    /// whenever `time_steps > 0` (see `RecInferenceSession::run`).
    pub valid_timesteps: Vec<usize>,
}
20
/// ONNX recognition (SVTR) session that lazily specializes and caches an
/// optimized plan per concrete `(batch, width)` input shape.
#[derive(Debug)]
pub struct RecInferenceSession {
    /// Base model loaded with symbolic batch/width dimensions; cloned and
    /// pinned to concrete dims before optimization.
    base_model: InferenceModel,
    /// Optimized runnable plans keyed by `(batch_size, width)`.
    /// `RefCell` keeps the public API `&self`; note this makes the session
    /// non-`Sync`, so it is intended for single-threaded use.
    cache: RefCell<HashMap<(usize, u32), Arc<TypedRunnableModel<TypedModel>>>>,
}
27
28impl RecInferenceSession {
29 pub fn load(model_path: impl AsRef<Path>) -> TractResult<Self> {
30 let model_path = model_path.as_ref();
31 println!("[RecInfer] Loading recognition model from {:?}", model_path);
32
33 let mut inference_model = tract_onnx::onnx()
34 .with_ignore_output_shapes(true)
35 .model_for_path(model_path)?;
36
37 let batch = inference_model.symbol_table.sym("batch");
38 let width = inference_model.symbol_table.sym("width");
39 inference_model.set_input_fact(
40 0,
41 InferenceFact::dt_shape(
42 f32::datum_type(),
43 tvec![batch.into(), TDim::from(3), TDim::from(48), width.into()],
44 ),
45 )?;
46
47 println!("[RecInfer] Recognition model prepared");
48 Ok(Self {
49 base_model: inference_model,
50 cache: RefCell::new(HashMap::new()),
51 })
52 }
53
54 pub fn run(&self, batch: &PreprocessedRecBatch) -> TractResult<RecInferenceOutput> {
55 let tensor_shape = batch.tensor.shape();
56 if tensor_shape.len() != 4 {
57 return Err(anyhow!(
58 "expected recognition input tensor to have 4 dimensions, got {:?}",
59 tensor_shape
60 )
61 .into());
62 }
63
64 let batch_size = tensor_shape[0];
65 let channel = tensor_shape[1];
66 let height = tensor_shape[2];
67 let width = tensor_shape[3];
68
69 println!(
70 "[RecInfer] Running inference with input shape {:?}",
71 tensor_shape
72 );
73
74 if channel != 3 || height != 48 {
75 return Err(anyhow!(
76 "expected recognition input to have shape [*, 3, 48, *], got {:?}",
77 tensor_shape
78 )
79 .into());
80 }
81
82 let plan = self.runnable_for_dims(batch_size, width as u32)?;
83 let outputs = plan.run(tvec!(batch.tensor.clone().into()))?;
84 let output_tensor = outputs
85 .into_iter()
86 .next()
87 .ok_or_else(|| anyhow!("SVTR model did not return any outputs"))?;
88
89 let view = output_tensor.to_array_view::<f32>()?;
90 if view.ndim() != 3 {
91 return Err(anyhow!(
92 "expected recognition output to have 3 dimensions, got {:?}",
93 view.shape()
94 )
95 .into());
96 }
97
98 let logits = view.into_dimensionality::<ndarray::Ix3>()?.to_owned();
99 let (logit_batch, time_steps, _classes) = logits.dim();
100 if logit_batch != batch_size {
101 return Err(anyhow!(
102 "batch dimension mismatch between input ({}) and output ({})",
103 batch_size,
104 logit_batch
105 )
106 .into());
107 }
108
109 let max_width = batch.max_width as f32;
110 let scale = if max_width > 0.0 {
111 time_steps as f32 / max_width
112 } else {
113 0.0
114 };
115 let valid_timesteps = batch
116 .valid_widths
117 .iter()
118 .map(|width| {
119 let mut steps = if scale > 0.0 {
120 (scale * *width as f32).round() as isize
121 } else {
122 time_steps as isize
123 };
124 if steps < 1 {
125 steps = 1;
126 }
127 if steps as usize > time_steps {
128 steps = time_steps as isize;
129 }
130 steps as usize
131 })
132 .collect::<Vec<_>>();
133
134 Ok(RecInferenceOutput {
135 logits,
136 valid_timesteps,
137 })
138 }
139
140 fn runnable_for_dims(
141 &self,
142 batch_size: usize,
143 width: u32,
144 ) -> TractResult<Arc<TypedRunnableModel<TypedModel>>> {
145 if let Some(plan) = self.cache.borrow().get(&(batch_size, width)) {
146 return Ok(Arc::clone(plan));
147 }
148
149 println!(
150 "[RecInfer] Preparing runnable model for batch {} width {}",
151 batch_size, width
152 );
153
154 let mut model = self.base_model.clone();
155 model.set_input_fact(
156 0,
157 InferenceFact::dt_shape(
158 f32::datum_type(),
159 tvec![
160 TDim::from(batch_size as i64),
161 TDim::from(3),
162 TDim::from(48),
163 TDim::from(width as i64)
164 ],
165 ),
166 )?;
167
168 let plan = model
169 .into_typed()?
170 .into_decluttered()?
171 .into_optimized()?
172 .into_runnable()?;
173
174 let plan = Arc::new(plan);
175 self.cache
176 .borrow_mut()
177 .insert((batch_size, width), Arc::clone(&plan));
178
179 Ok(plan)
180 }
181}
182
/// Configuration for [`RecPostProcessor`].
#[derive(Debug, Clone)]
pub struct RecPostProcessorConfig {
    /// Class index treated as the CTC blank symbol by the greedy decoder.
    pub blank_id: usize,
    /// Token the decoder substitutes when it cannot resolve a class index
    /// against the dictionary (see the fallback test below).
    pub fallback_token: String,
}
189
190impl Default for RecPostProcessorConfig {
191 fn default() -> Self {
192 Self {
193 blank_id: 0,
194 fallback_token: "[UNK]".to_string(),
195 }
196 }
197}
198
/// Errors produced while post-processing recognition output.
#[derive(Debug)]
pub enum RecPostProcessorError {
    /// The underlying CTC greedy decoder failed.
    Decoder(CtcGreedyDecoderError),
}
204
205impl std::fmt::Display for RecPostProcessorError {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match self {
208 RecPostProcessorError::Decoder(err) => write!(f, "ctc decoder failed: {}", err),
209 }
210 }
211}
212
213impl std::error::Error for RecPostProcessorError {
214 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
215 match self {
216 RecPostProcessorError::Decoder(err) => Some(err),
217 }
218 }
219}
220
221impl From<CtcGreedyDecoderError> for RecPostProcessorError {
222 fn from(value: CtcGreedyDecoderError) -> Self {
223 RecPostProcessorError::Decoder(value)
224 }
225}
226
/// Turns raw recognition logits into decoded text via greedy CTC decoding
/// against a recognition dictionary.
#[derive(Debug, Clone)]
pub struct RecPostProcessor {
    // Greedy CTC decoder configured from `RecPostProcessorConfig`.
    decoder: CtcGreedyDecoder,
    // Shared token dictionary used to map class indices to text.
    dictionary: Arc<RecDictionary>,
}
233
234impl RecPostProcessor {
235 pub fn new(dictionary: Arc<RecDictionary>, config: RecPostProcessorConfig) -> Self {
236 let decoder = CtcGreedyDecoder::new(CtcGreedyDecoderConfig {
237 blank_id: config.blank_id,
238 fallback_token: Some(config.fallback_token),
239 });
240 Self {
241 decoder,
242 dictionary,
243 }
244 }
245
246 pub fn process(
247 &self,
248 output: &RecInferenceOutput,
249 ) -> Result<Vec<DecodedSequence>, RecPostProcessorError> {
250 self.decoder
251 .decode(&output.logits, &output.valid_timesteps, &self.dictionary)
252 .map_err(RecPostProcessorError::from)
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dictionary::RecDictionary;
    use crate::preprocessing::{RecPreProcessor, RecPreProcessorConfig, RecTextRegion};
    use image::{DynamicImage, ImageBuffer, Rgb};
    use ndarray::Array3;
    use std::env;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Builds a deterministic RGB gradient image so inference input is stable
    // across runs without shipping an image fixture.
    fn gradient_image(width: u32, height: u32) -> DynamicImage {
        let mut buffer = ImageBuffer::new(width, height);
        for (x, y, pixel) in buffer.enumerate_pixels_mut() {
            let base = ((x + y) % 256) as u8;
            let green = base.saturating_add(16);
            let blue = base.saturating_add(32);
            *pixel = Rgb([base, green, blue]);
        }
        DynamicImage::ImageRgb8(buffer)
    }

    // Writes `tokens` (one per line) to a nanosecond-stamped temp file, loads
    // it as a dictionary, then best-effort removes the file.
    // NOTE(review): the timestamp makes name collisions unlikely but not
    // impossible under parallel test execution.
    fn dictionary_from_tokens(tokens: &[&str]) -> RecDictionary {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let path = std::env::temp_dir().join(format!("rec_post_dict_{}.txt", timestamp));
        fs::write(&path, tokens.join("\n")).unwrap();
        let dict = RecDictionary::from_path(&path).unwrap();
        fs::remove_file(path).ok();
        dict
    }

    // Resolves a PP-OCRv5 asset by probing PURE_ONNX_OCR_FIXTURE_DIR (and its
    // `models` subdir) followed by conventional in-repo fixture directories;
    // each base is checked both with and without a `ppocrv5` subdirectory.
    fn locate_ppocrv5_asset(file_name: &str) -> Option<PathBuf> {
        let mut bases: Vec<PathBuf> = Vec::new();
        if let Some(dir) = env::var_os("PURE_ONNX_OCR_FIXTURE_DIR") {
            let env_path = PathBuf::from(dir);
            bases.push(env_path.clone());
            bases.push(env_path.join("models"));
        }

        let manifest = Path::new(env!("CARGO_MANIFEST_DIR"));
        bases.push(manifest.join("tests").join("fixtures").join("models"));
        bases.push(manifest.join("tests").join("fixtures"));
        bases.push(manifest.join("models"));

        for base in bases {
            let ppocr_dir = base.join("ppocrv5");
            let candidate = ppocr_dir.join(file_name);
            if candidate.exists() {
                return Some(candidate);
            }

            let alt = base.join(file_name);
            if alt.exists() {
                return Some(alt);
            }
        }

        None
    }

    // End-to-end smoke test: preprocess one synthetic crop, run the real ONNX
    // model, and sanity-check the logits shape and valid-timestep bookkeeping.
    // Requires the `rec.onnx` fixture to be present (see locate_ppocrv5_asset).
    #[test]
    fn recognition_inference_runs() -> TractResult<()> {
        let model_path =
            locate_ppocrv5_asset("rec.onnx").expect("expected SVTR model under models/ppocrv5/");

        let session = RecInferenceSession::load(model_path)?;

        let image = gradient_image(320, 160);
        let preprocessor = RecPreProcessor::new(RecPreProcessorConfig::default());
        let regions = vec![RecTextRegion {
            x: 10,
            y: 20,
            width: 120,
            height: 60,
        }];
        let batch = preprocessor
            .process(&image, &regions)
            .expect("recognition preprocessing should succeed");

        let output = session.run(&batch)?;
        let shape = output.logits.dim();

        // One region in -> one sequence out, with non-empty time/class axes
        // and a valid-timestep count bounded by the time axis.
        assert_eq!(shape.0, 1);
        assert!(shape.1 > 0);
        assert!(shape.2 > 0);
        assert_eq!(output.valid_timesteps.len(), 1);
        assert!(output.valid_timesteps[0] <= shape.1);

        Ok(())
    }

    // Hand-crafted logits (2 samples x 4 timesteps x 4 classes, blank = 0):
    // sample 0 argmaxes to [blank, 1, 1, 2] -> CTC-collapses to "ab";
    // sample 1 has one valid timestep whose argmax (class 3) lies outside the
    // two-token dictionary, so it must yield the "[UNK]" fallback.
    #[test]
    fn post_processor_decodes_with_fallback() {
        let logits = Array3::from_shape_vec(
            (2, 4, 4),
            vec![
                5.0, 0.1, -1.0, -2.0,
                -2.0, 4.5, 0.0, -3.0,
                -3.0, 4.2, -0.5, -3.5,
                -4.0, -1.0, 4.8, -3.0,
                -6.0, -5.0, 1.0, 4.5,
                5.0, 0.0, -1.0, -2.0,
                5.0, 0.0, -1.0, -2.0,
                5.0, 0.0, -1.0, -2.0,
            ],
        )
        .unwrap();
        let output = RecInferenceOutput {
            logits,
            valid_timesteps: vec![4, 1],
        };

        let dictionary = Arc::new(dictionary_from_tokens(&["a", "b"]));
        let processor = RecPostProcessor::new(
            Arc::clone(&dictionary),
            RecPostProcessorConfig {
                blank_id: 0,
                fallback_token: "[UNK]".to_string(),
            },
        );

        let sequences = processor.process(&output).expect("decoding succeeds");
        assert_eq!(sequences.len(), 2);

        assert_eq!(sequences[0].text, "ab");
        assert_eq!(sequences[0].fallback_count, 0);

        assert_eq!(sequences[1].text, "[UNK]");
        assert_eq!(sequences[1].fallback_count, 1);
    }
}