// pure_onnx_ocr/lib.rs

//! # Pure ONNX OCR
//!
//! A Pure Rust OCR pipeline that mirrors the PaddleOCR DBNet + SVTR stack.
//! The crate exposes ergonomic builders and processing stages that let you
//! load ONNX models, prepare image batches, and decode recognition logits
//! without any C/C++ dependencies.
//!
//! Most consumers interact with [`OcrEngineBuilder`] to construct an
//! [`OcrEngine`], then call [`OcrEngine::run_from_path`] or
//! [`OcrEngine::run_from_image`].  Lower-level modules remain available
//! when you need to plug specific stages into an existing pipeline.

13pub mod ctc;
14pub mod detection;
15pub mod dictionary;
16pub mod engine;
17pub mod postprocessing;
18pub mod preprocessing;
19pub mod recognition;
20
21/// Re-export of the CTC decoding utilities so applications can customise
22/// post-processing while keeping consistent types.
23pub use ctc::{CtcGreedyDecoder, CtcGreedyDecoderConfig, CtcGreedyDecoderError, DecodedSequence};
24/// Re-export of detection inference helpers for direct DBNet integration.
25pub use detection::{DetInferenceOutput, DetInferenceSession};
26pub use dictionary::{DictionaryError, RecDictionary};
27/// High-level façade providing an ergonomic OCR API.
28pub use engine::{
29    OcrEngine, OcrEngineBuilder, OcrEngineConfig, OcrError, OcrResult, OcrRunWithMetrics,
30    OcrTimings, StageTimings,
31};
32/// Geometry primitives surfaced at the crate root for convenience.
33pub use geo_types::{Point, Polygon};
34pub use postprocessing::{
35    DetPolygonScaler, DetPolygonScalerConfig, DetPolygonUnclipper, DetPolygonUnclipperConfig,
36    DetPostProcessor, DetPostProcessorConfig, DetPostProcessorError, DetScaleRounding,
37    DetUnclipLineJoin,
38};
39pub use preprocessing::{
40    DetPreProcessor, DetPreProcessorConfig, DetPreProcessorError, PreprocessedDetInput,
41    PreprocessedRecBatch, RecPreProcessor, RecPreProcessorConfig, RecPreProcessorError,
42    RecTextRegion,
43};
44pub use recognition::{
45    RecInferenceOutput, RecInferenceSession, RecPostProcessor, RecPostProcessorConfig,
46    RecPostProcessorError,
47};
48
49use std::path::Path;
50use tract_onnx::prelude::*;
51
52const DBNET_DUMMY_SHAPE: [usize; 4] = [1, 3, 320, 320];
53const SVTR_DUMMY_SHAPE: [usize; 4] = [1, 3, 48, 320];
54
55/// Run a `tract-onnx` dummy inference against a DBNet detection model.
56///
57/// The helper loads `model_path`, feeds a zero-filled tensor with the
58/// expected DBNet input shape, and returns the produced tensors.  The
59/// logs include timing information for each optimisation step, which is
60/// particularly helpful when first validating an ONNX export.
61///
62/// # Examples
63///
64/// ```no_run
65/// use pure_onnx_ocr::run_dbnet_dummy_inference;
66///
67/// let outputs = run_dbnet_dummy_inference("models/ppocrv5/det.onnx")
68///     .expect("model should load and execute");
69/// assert!(!outputs.is_empty());
70/// ```
71///
72/// # Errors
73///
74/// Returns [`tract_onnx::prelude::TractError`] when the model cannot be
75/// loaded, optimised, or executed with the provided dummy tensor.
76pub fn run_dbnet_dummy_inference(model_path: impl AsRef<Path>) -> TractResult<TVec<Tensor>> {
77    let dummy_input: Tensor = tract_ndarray::Array4::<f32>::zeros(DBNET_DUMMY_SHAPE)
78        .into_dyn()
79        .into();
80    run_dummy_inference(model_path, dummy_input, "DBNet")
81}
82
83/// Run a `tract-onnx` dummy inference against an SVTR recognition model.
84///
85/// The helper constructs a synthetic sinusoidal input tensor to exercise
86/// the model, runs end-to-end optimisation, and returns the resulting
87/// logits.  Use this to confirm that the SVTR export can be handled by
88/// `tract-onnx` before attempting full OCR integration.
89///
90/// # Examples
91///
92/// ```no_run
93/// use pure_onnx_ocr::run_svtr_dummy_inference;
94///
95/// let outputs = run_svtr_dummy_inference("models/ppocrv5/rec.onnx")
96///     .expect("model should load and execute");
97/// assert!(!outputs.is_empty());
98/// ```
99///
100/// # Errors
101///
102/// Returns [`tract_onnx::prelude::TractError`] when the model cannot be
103/// loaded, optimised, or executed.
104pub fn run_svtr_dummy_inference(model_path: impl AsRef<Path>) -> TractResult<TVec<Tensor>> {
105    let dummy_input: Tensor =
106        tract_ndarray::Array4::<f32>::from_shape_fn(SVTR_DUMMY_SHAPE, |(_, channel, row, col)| {
107            // 正規化された斜めグラデーション: チャンネルごとにスケールを変えて変化を持たせる
108            let spatial_size = (SVTR_DUMMY_SHAPE[2] * SVTR_DUMMY_SHAPE[3]) as f32;
109            let base = (row * SVTR_DUMMY_SHAPE[3] + col) as f32 / spatial_size;
110            let channel_scale = 0.1 * channel as f32;
111            (base + channel_scale).sin()
112        })
113        .into_dyn()
114        .into();
115    run_dummy_inference(model_path, dummy_input, "SVTR")
116}
117
118fn run_dummy_inference(
119    model_path: impl AsRef<Path>,
120    dummy_input: Tensor,
121    label: &str,
122) -> TractResult<TVec<Tensor>> {
123    let model_path = model_path.as_ref();
124    println!("[{}] Loading model from {:?}", label, model_path);
125
126    let start = std::time::Instant::now();
127
128    let mut model = tract_onnx::onnx()
129        .with_ignore_output_shapes(true)
130        .model_for_path(model_path)?;
131    println!("[{}] Model loaded, elapsed: {:?}", label, start.elapsed());
132
133    model.set_input_fact(0, InferenceFact::from(&dummy_input))?;
134    println!("[{}] Input fact set, elapsed: {:?}", label, start.elapsed());
135
136    println!(
137        "[{}] Starting model conversion to typed, elapsed: {:?}",
138        label,
139        start.elapsed()
140    );
141    let model = model.into_typed()?;
142
143    println!(
144        "[{}] Starting decluttering, elapsed: {:?}",
145        label,
146        start.elapsed()
147    );
148    let model = model.into_decluttered()?;
149
150    println!(
151        "[{}] Starting optimization, elapsed: {:?}",
152        label,
153        start.elapsed()
154    );
155    let model = model.into_optimized()?;
156
157    println!(
158        "[{}] Making runnable, elapsed: {:?}",
159        label,
160        start.elapsed()
161    );
162    let model = model.into_runnable()?;
163
164    println!("[{}] Total preparation time: {:?}", label, start.elapsed());
165
166    println!(
167        "[{}] Running inference, elapsed: {:?}",
168        label,
169        start.elapsed()
170    );
171    let outputs = model.run(tvec!(dummy_input.into()))?;
172    println!(
173        "[{}] Inference complete, elapsed: {:?}",
174        label,
175        start.elapsed()
176    );
177
178    Ok(outputs
179        .into_iter()
180        .map(|value| value.into_tensor())
181        .collect::<TVec<_>>())
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    #[ignore = "dummy inference takes >60s; run with `cargo test -- --ignored`"]
    fn dbnet_dummy_inference_runs_successfully() -> TractResult<()> {
        let model_path = Path::new("models/ppocrv5/det.onnx");
        assert!(
            model_path.exists(),
            "expected DBNet model at {:?} to exist",
            model_path
        );

        println!("Starting inference test");
        let outputs = run_dbnet_dummy_inference(model_path)?;
        println!("Test completed with {} outputs", outputs.len());

        assert!(
            !outputs.is_empty(),
            "inference should return at least one output tensor"
        );

        // Print each output tensor's shape for manual inspection.
        for (i, tensor) in outputs.iter().enumerate() {
            println!("Output tensor #{} shape: {:?}", i, tensor.shape());
        }

        Ok(())
    }

    #[test]
    #[ignore = "dummy inference takes >60s; run with `cargo test -- --ignored`"]
    fn svtr_dummy_inference_runs_successfully() -> TractResult<()> {
        let model_path = Path::new("models/ppocrv5/rec.onnx");
        assert!(
            model_path.exists(),
            "expected SVTR model at {:?} to exist",
            model_path
        );

        println!("Starting SVTR inference test");
        let outputs = run_svtr_dummy_inference(model_path)?;
        println!("SVTR test completed with {} outputs", outputs.len());

        assert!(
            !outputs.is_empty(),
            "SVTR inference should return at least one output tensor"
        );

        let first = &outputs[0];
        println!("SVTR output tensor shape: {:?}", first.shape());

        let shape = first.shape().to_vec();
        assert_eq!(
            shape.first().copied(),
            Some(1),
            "SVTR batch dimension should be 1"
        );
        assert!(
            shape.iter().skip(1).all(|dim| *dim > 0),
            "SVTR tensor dimensions after batch should be positive"
        );

        // Sanity-check the logits: finite values with a non-zero spread.
        let view = first.to_array_view::<f32>()?;
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        for value in view.iter() {
            min = min.min(*value);
            max = max.max(*value);
        }
        println!("SVTR output value range: min={:.6}, max={:.6}", min, max);

        assert!(
            min.is_finite() && max.is_finite(),
            "SVTR output values should be finite numbers"
        );
        assert!(
            max > min,
            "SVTR output values should have a non-zero dynamic range"
        );

        Ok(())
    }
}