// pure_onnx_ocr/recognition.rs

1use crate::ctc::{
2    CtcGreedyDecoder, CtcGreedyDecoderConfig, CtcGreedyDecoderError, DecodedSequence,
3};
4use crate::dictionary::RecDictionary;
5use crate::preprocessing::PreprocessedRecBatch;
6use ndarray::Array3;
7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::path::Path;
10use std::sync::Arc;
11use tract_onnx::prelude::*;
12use tract_onnx::tract_core::anyhow::anyhow;
13
/// Result of running SVTR recognition inference.
#[derive(Debug, Clone)]
pub struct RecInferenceOutput {
    /// Raw model output, shape `[batch, time_steps, num_classes]`.
    pub logits: Array3<f32>,
    /// Number of logit time steps that correspond to real (unpadded) input
    /// width, one entry per batch element; each value is in `1..=time_steps`.
    pub valid_timesteps: Vec<usize>,
}
20
/// Runnable inference session for SVTR recognition model.
///
/// Holds the un-specialized base model and lazily builds an optimized plan
/// for each distinct `(batch, width)` input shape, caching the results.
#[derive(Debug)]
pub struct RecInferenceSession {
    // Base model with symbolic batch/width dims; cloned and concretized per shape.
    base_model: InferenceModel,
    // Optimized runnable plans keyed by (batch_size, input_width).
    // RefCell: interior mutability so `run(&self)` can populate the cache;
    // this makes the type single-threaded by design.
    cache: RefCell<HashMap<(usize, u32), Arc<TypedRunnableModel<TypedModel>>>>,
}
27
28impl RecInferenceSession {
29    pub fn load(model_path: impl AsRef<Path>) -> TractResult<Self> {
30        let model_path = model_path.as_ref();
31        println!("[RecInfer] Loading recognition model from {:?}", model_path);
32
33        let mut inference_model = tract_onnx::onnx()
34            .with_ignore_output_shapes(true)
35            .model_for_path(model_path)?;
36
37        let batch = inference_model.symbol_table.sym("batch");
38        let width = inference_model.symbol_table.sym("width");
39        inference_model.set_input_fact(
40            0,
41            InferenceFact::dt_shape(
42                f32::datum_type(),
43                tvec![batch.into(), TDim::from(3), TDim::from(48), width.into()],
44            ),
45        )?;
46
47        println!("[RecInfer] Recognition model prepared");
48        Ok(Self {
49            base_model: inference_model,
50            cache: RefCell::new(HashMap::new()),
51        })
52    }
53
54    pub fn run(&self, batch: &PreprocessedRecBatch) -> TractResult<RecInferenceOutput> {
55        let tensor_shape = batch.tensor.shape();
56        if tensor_shape.len() != 4 {
57            return Err(anyhow!(
58                "expected recognition input tensor to have 4 dimensions, got {:?}",
59                tensor_shape
60            )
61            .into());
62        }
63
64        let batch_size = tensor_shape[0];
65        let channel = tensor_shape[1];
66        let height = tensor_shape[2];
67        let width = tensor_shape[3];
68
69        println!(
70            "[RecInfer] Running inference with input shape {:?}",
71            tensor_shape
72        );
73
74        if channel != 3 || height != 48 {
75            return Err(anyhow!(
76                "expected recognition input to have shape [*, 3, 48, *], got {:?}",
77                tensor_shape
78            )
79            .into());
80        }
81
82        let plan = self.runnable_for_dims(batch_size, width as u32)?;
83        let outputs = plan.run(tvec!(batch.tensor.clone().into()))?;
84        let output_tensor = outputs
85            .into_iter()
86            .next()
87            .ok_or_else(|| anyhow!("SVTR model did not return any outputs"))?;
88
89        let view = output_tensor.to_array_view::<f32>()?;
90        if view.ndim() != 3 {
91            return Err(anyhow!(
92                "expected recognition output to have 3 dimensions, got {:?}",
93                view.shape()
94            )
95            .into());
96        }
97
98        let logits = view.into_dimensionality::<ndarray::Ix3>()?.to_owned();
99        let (logit_batch, time_steps, _classes) = logits.dim();
100        if logit_batch != batch_size {
101            return Err(anyhow!(
102                "batch dimension mismatch between input ({}) and output ({})",
103                batch_size,
104                logit_batch
105            )
106            .into());
107        }
108
109        let max_width = batch.max_width as f32;
110        let scale = if max_width > 0.0 {
111            time_steps as f32 / max_width
112        } else {
113            0.0
114        };
115        let valid_timesteps = batch
116            .valid_widths
117            .iter()
118            .map(|width| {
119                let mut steps = if scale > 0.0 {
120                    (scale * *width as f32).round() as isize
121                } else {
122                    time_steps as isize
123                };
124                if steps < 1 {
125                    steps = 1;
126                }
127                if steps as usize > time_steps {
128                    steps = time_steps as isize;
129                }
130                steps as usize
131            })
132            .collect::<Vec<_>>();
133
134        Ok(RecInferenceOutput {
135            logits,
136            valid_timesteps,
137        })
138    }
139
140    fn runnable_for_dims(
141        &self,
142        batch_size: usize,
143        width: u32,
144    ) -> TractResult<Arc<TypedRunnableModel<TypedModel>>> {
145        if let Some(plan) = self.cache.borrow().get(&(batch_size, width)) {
146            return Ok(Arc::clone(plan));
147        }
148
149        println!(
150            "[RecInfer] Preparing runnable model for batch {} width {}",
151            batch_size, width
152        );
153
154        let mut model = self.base_model.clone();
155        model.set_input_fact(
156            0,
157            InferenceFact::dt_shape(
158                f32::datum_type(),
159                tvec![
160                    TDim::from(batch_size as i64),
161                    TDim::from(3),
162                    TDim::from(48),
163                    TDim::from(width as i64)
164                ],
165            ),
166        )?;
167
168        let plan = model
169            .into_typed()?
170            .into_decluttered()?
171            .into_optimized()?
172            .into_runnable()?;
173
174        let plan = Arc::new(plan);
175        self.cache
176            .borrow_mut()
177            .insert((batch_size, width), Arc::clone(&plan));
178
179        Ok(plan)
180    }
181}
182
/// Configuration for recognition post processing (CTC decoding stage).
#[derive(Debug, Clone)]
pub struct RecPostProcessorConfig {
    /// Class index treated as the CTC blank symbol (0 in the default config).
    pub blank_id: usize,
    /// Token emitted when a decoded class id has no dictionary entry.
    pub fallback_token: String,
}
189
190impl Default for RecPostProcessorConfig {
191    fn default() -> Self {
192        Self {
193            blank_id: 0,
194            fallback_token: "[UNK]".to_string(),
195        }
196    }
197}
198
/// Errors that can occur while decoding recognition logits into text.
#[derive(Debug)]
pub enum RecPostProcessorError {
    /// The underlying CTC greedy decoder reported a failure.
    Decoder(CtcGreedyDecoderError),
}
204
205impl std::fmt::Display for RecPostProcessorError {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        match self {
208            RecPostProcessorError::Decoder(err) => write!(f, "ctc decoder failed: {}", err),
209        }
210    }
211}
212
213impl std::error::Error for RecPostProcessorError {
214    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
215        match self {
216            RecPostProcessorError::Decoder(err) => Some(err),
217        }
218    }
219}
220
221impl From<CtcGreedyDecoderError> for RecPostProcessorError {
222    fn from(value: CtcGreedyDecoderError) -> Self {
223        RecPostProcessorError::Decoder(value)
224    }
225}
226
/// Recognition post processor that converts logits into decoded text.
#[derive(Debug, Clone)]
pub struct RecPostProcessor {
    // Greedy CTC decoder built from a RecPostProcessorConfig.
    decoder: CtcGreedyDecoder,
    // Shared token dictionary used to map class ids to text.
    dictionary: Arc<RecDictionary>,
}
233
234impl RecPostProcessor {
235    pub fn new(dictionary: Arc<RecDictionary>, config: RecPostProcessorConfig) -> Self {
236        let decoder = CtcGreedyDecoder::new(CtcGreedyDecoderConfig {
237            blank_id: config.blank_id,
238            fallback_token: Some(config.fallback_token),
239        });
240        Self {
241            decoder,
242            dictionary,
243        }
244    }
245
246    pub fn process(
247        &self,
248        output: &RecInferenceOutput,
249    ) -> Result<Vec<DecodedSequence>, RecPostProcessorError> {
250        self.decoder
251            .decode(&output.logits, &output.valid_timesteps, &self.dictionary)
252            .map_err(RecPostProcessorError::from)
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dictionary::RecDictionary;
    use crate::preprocessing::{RecPreProcessor, RecPreProcessorConfig, RecTextRegion};
    use image::{DynamicImage, ImageBuffer, Rgb};
    use ndarray::Array3;
    use std::env;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a deterministic RGB gradient image used as preprocessing input.
    fn gradient_image(width: u32, height: u32) -> DynamicImage {
        let mut buffer = ImageBuffer::new(width, height);
        for (x, y, pixel) in buffer.enumerate_pixels_mut() {
            let base = ((x + y) % 256) as u8;
            let green = base.saturating_add(16);
            let blue = base.saturating_add(32);
            *pixel = Rgb([base, green, blue]);
        }
        DynamicImage::ImageRgb8(buffer)
    }

    /// Writes `tokens` (one per line) to a unique temp file, loads a
    /// dictionary from it, and removes the file again.
    fn dictionary_from_tokens(tokens: &[&str]) -> RecDictionary {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        // Include the process id so concurrent test processes cannot collide
        // on the same file name.
        let path = env::temp_dir().join(format!(
            "rec_post_dict_{}_{}.txt",
            std::process::id(),
            timestamp
        ));
        fs::write(&path, tokens.join("\n")).unwrap();
        let dict = RecDictionary::from_path(&path);
        // Remove the file before unwrapping so a load failure does not leak it.
        fs::remove_file(&path).ok();
        dict.unwrap()
    }

    /// Searches the fixture locations (env override first, then in-repo
    /// directories) for a PP-OCRv5 asset, either under a `ppocrv5/`
    /// subdirectory or directly in the base directory.
    fn locate_ppocrv5_asset(file_name: &str) -> Option<PathBuf> {
        let mut bases: Vec<PathBuf> = Vec::new();
        if let Some(dir) = env::var_os("PURE_ONNX_OCR_FIXTURE_DIR") {
            let env_path = PathBuf::from(dir);
            bases.push(env_path.clone());
            bases.push(env_path.join("models"));
        }

        let manifest = Path::new(env!("CARGO_MANIFEST_DIR"));
        bases.push(manifest.join("tests").join("fixtures").join("models"));
        bases.push(manifest.join("tests").join("fixtures"));
        bases.push(manifest.join("models"));

        for base in bases {
            let ppocr_dir = base.join("ppocrv5");
            let candidate = ppocr_dir.join(file_name);
            if candidate.exists() {
                return Some(candidate);
            }

            let alt = base.join(file_name);
            if alt.exists() {
                return Some(alt);
            }
        }

        None
    }

    /// End-to-end smoke test: preprocess a synthetic region and run the real
    /// SVTR model (requires the fixture model to be present).
    #[test]
    fn recognition_inference_runs() -> TractResult<()> {
        let model_path =
            locate_ppocrv5_asset("rec.onnx").expect("expected SVTR model under models/ppocrv5/");

        let session = RecInferenceSession::load(model_path)?;

        let image = gradient_image(320, 160);
        let preprocessor = RecPreProcessor::new(RecPreProcessorConfig::default());
        let regions = vec![RecTextRegion {
            x: 10,
            y: 20,
            width: 120,
            height: 60,
        }];
        let batch = preprocessor
            .process(&image, &regions)
            .expect("recognition preprocessing should succeed");

        let output = session.run(&batch)?;
        let shape = output.logits.dim();

        assert_eq!(shape.0, 1);
        assert!(shape.1 > 0);
        assert!(shape.2 > 0);
        assert_eq!(output.valid_timesteps.len(), 1);
        assert!(output.valid_timesteps[0] <= shape.1);

        Ok(())
    }

    /// Decoding test with hand-built logits: the first sequence decodes to
    /// known tokens, the second hits a class id outside the dictionary and
    /// must produce the fallback token.
    #[test]
    fn post_processor_decodes_with_fallback() {
        let logits = Array3::from_shape_vec(
            (2, 4, 4),
            vec![
                5.0, 0.1, -1.0, -2.0, //
                -2.0, 4.5, 0.0, -3.0, //
                -3.0, 4.2, -0.5, -3.5, //
                -4.0, -1.0, 4.8, -3.0, //
                // second sequence with unknown indices
                -6.0, -5.0, 1.0, 4.5, //
                5.0, 0.0, -1.0, -2.0, //
                5.0, 0.0, -1.0, -2.0, //
                5.0, 0.0, -1.0, -2.0, //
            ],
        )
        .unwrap();
        let output = RecInferenceOutput {
            logits,
            valid_timesteps: vec![4, 1],
        };

        let dictionary = Arc::new(dictionary_from_tokens(&["a", "b"]));
        let processor = RecPostProcessor::new(
            Arc::clone(&dictionary),
            RecPostProcessorConfig {
                blank_id: 0,
                fallback_token: "[UNK]".to_string(),
            },
        );

        let sequences = processor.process(&output).expect("decoding succeeds");
        assert_eq!(sequences.len(), 2);

        assert_eq!(sequences[0].text, "ab");
        assert_eq!(sequences[0].fallback_count, 0);

        assert_eq!(sequences[1].text, "[UNK]");
        assert_eq!(sequences[1].fallback_count, 1);
    }
}
391}