//! ONNX-based license-plate recognizer (tract-onnx backend).
2
3use crate::{
4    config::PlateConfig,
5    hub::{download_model, OcrModel},
6    process::{images_to_batch, postprocess_output, read_and_resize_plate_image, PlatePrediction},
7};
8use anyhow::{bail, Context};
9use image::DynamicImage;
10use std::{
11    path::Path,
12    time::Instant,
13};
14use tract_onnx::prelude::*;
15
16// ---------------------------------------------------------------------------
17// Input type
18// ---------------------------------------------------------------------------
19
/// A single plate input: either a file path or an already-decoded image.
pub enum PlateInput<'a> {
    /// Borrowed path to an image file on disk; read and resized at inference time.
    Path(&'a Path),
    /// An image that has already been decoded in memory; resized at inference time.
    Image(DynamicImage),
}
25
26impl<'a> From<&'a str> for PlateInput<'a> {
27    fn from(s: &'a str) -> Self {
28        PlateInput::Path(Path::new(s))
29    }
30}
31
32impl<'a> From<&'a Path> for PlateInput<'a> {
33    fn from(p: &'a Path) -> Self {
34        PlateInput::Path(p)
35    }
36}
37
38impl From<DynamicImage> for PlateInput<'_> {
39    fn from(img: DynamicImage) -> Self {
40        PlateInput::Image(img)
41    }
42}
43
44// ---------------------------------------------------------------------------
45// Recognizer
46// ---------------------------------------------------------------------------
47
48type OnnxModel = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
49
/// Inference class for license-plate OCR using tract-onnx.
pub struct LicensePlateRecognizer {
    /// Loaded ONNX plan (optimised graph).
    model: OnnxModel,
    /// Plate configuration (image size, alphabet, etc.).
    pub config: PlateConfig,
    /// Human-readable model name (file stem, or hub identifier for hub models).
    pub model_name: String,
    /// Index of the plate output in `model.run()` results (0 by export convention).
    plate_output_idx: usize,
    /// Index of the region output (`Some(1)` when the model has a second output).
    region_output_idx: Option<usize>,
    /// `true` when the model has a region head and the config defines regions.
    has_region_head: bool,
}
65
66impl LicensePlateRecognizer {
67    // -----------------------------------------------------------------------
68    // Constructors
69    // -----------------------------------------------------------------------
70
71    /// Build a recognizer from an existing ONNX model file and config file.
72    pub fn from_files(
73        onnx_model_path: impl AsRef<Path>,
74        plate_config_path: impl AsRef<Path>,
75    ) -> anyhow::Result<Self> {
76        let onnx_path = onnx_model_path.as_ref();
77        let cfg_path = plate_config_path.as_ref();
78
79        if !onnx_path.exists() {
80            bail!("ONNX model not found: {}", onnx_path.display());
81        }
82        if !cfg_path.exists() {
83            bail!("Plate config not found: {}", cfg_path.display());
84        }
85
86        let model_name = onnx_path
87            .file_stem()
88            .map(|s| s.to_string_lossy().into_owned())
89            .unwrap_or_else(|| "custom".to_owned());
90
91        let config = PlateConfig::from_yaml(cfg_path)?;
92
93        let model = Self::load_model(onnx_path, &config)?;
94
95        Self::from_model_and_config(model, config, model_name)
96    }
97
98    /// Download (if necessary) and load a model from the hub.
99    pub fn from_hub(model: OcrModel, force_download: bool) -> anyhow::Result<Self> {
100        let model_name = model.as_str().to_owned();
101        let (onnx_path, cfg_path) = download_model(&model, None, force_download)?;
102        let mut recognizer = Self::from_files(onnx_path, cfg_path)?;
103        recognizer.model_name = model_name;
104        Ok(recognizer)
105    }
106
107    /// Download (if necessary) and load a model from the hub, saving to a specific directory.
108    pub fn from_hub_to_dir(
109        model: OcrModel,
110        save_dir: &Path,
111        force_download: bool,
112    ) -> anyhow::Result<Self> {
113        let model_name = model.as_str().to_owned();
114        let (onnx_path, cfg_path) = download_model(&model, Some(save_dir), force_download)?;
115        let mut recognizer = Self::from_files(onnx_path, cfg_path)?;
116        recognizer.model_name = model_name;
117        Ok(recognizer)
118    }
119
120    // -----------------------------------------------------------------------
121    // Internal helpers
122    // -----------------------------------------------------------------------
123
124    fn load_model(onnx_path: &Path, config: &PlateConfig) -> anyhow::Result<OnnxModel> {
125        let h = config.img_height;
126        let w = config.img_width;
127        let c = config.num_channels();
128
129        let plan = tract_onnx::onnx()
130            .model_for_path(onnx_path)
131            .context("Cannot parse ONNX model")?
132            .with_input_fact(
133                0,
134                InferenceFact::dt_shape(u8::datum_type(), tvec![1usize, h as usize, w as usize, c as usize]),
135            )
136            .context("Cannot set input fact")?
137            .into_optimized()
138            .context("Cannot optimise ONNX model")?
139            .into_runnable()
140            .context("Cannot make model runnable")?;
141
142        Ok(plan)
143    }
144
145    fn from_model_and_config(
146        model: OnnxModel,
147        config: PlateConfig,
148        model_name: String,
149    ) -> anyhow::Result<Self> {
150        // tract does not expose named outputs by default; we rely on output index ordering.
151        // By convention in fast-plate-ocr ONNX exports:
152        //   output 0 = plate logits
153        //   output 1 = region logits (only when the model has a region head)
154        let num_outputs = model.model().output_outlets()?.len();
155        let plate_output_idx = 0;
156        let region_output_idx = if num_outputs > 1 { Some(1) } else { None };
157
158        let has_region_head = region_output_idx.is_some() && config.has_region_recognition();
159
160        if region_output_idx.is_none() && config.has_region_recognition() {
161            eprintln!(
162                "Warning: plate config declares regions but the model has only one output. \
163                 Region predictions will be disabled."
164            );
165        }
166        if region_output_idx.is_some() && !config.has_region_recognition() {
167            eprintln!(
168                "Warning: model has a second output but the plate config has no region list. \
169                 Region predictions will be disabled."
170            );
171        }
172
173        Ok(LicensePlateRecognizer {
174            model,
175            config,
176            model_name,
177            plate_output_idx,
178            region_output_idx,
179            has_region_head,
180        })
181    }
182
183    // -----------------------------------------------------------------------
184    // Inference
185    // -----------------------------------------------------------------------
186
187    /// Run plate recognition on one or more plate images.
188    ///
189    /// `inputs` is a slice of `PlateInput` (paths or pre-loaded images).  
190    /// Returns one `PlatePrediction` per input.
191    pub fn run(
192        &self,
193        inputs: &[PlateInput<'_>],
194        return_confidence: bool,
195        remove_pad_char: bool,
196    ) -> anyhow::Result<Vec<PlatePrediction>> {
197        if inputs.is_empty() {
198            return Ok(vec![]);
199        }
200
201        // --- Load & resize images ---
202        let imgs: Vec<DynamicImage> = inputs
203            .iter()
204            .map(|inp| match inp {
205                PlateInput::Path(p) => read_and_resize_plate_image(p, &self.config),
206                PlateInput::Image(img) => {
207                    crate::process::resize_image(
208                        img.clone(),
209                        self.config.img_height,
210                        self.config.img_width,
211                        &self.config.image_color_mode,
212                        self.config.keep_aspect_ratio,
213                        &self.config.interpolation,
214                        &self.config.padding_color,
215                    )
216                }
217            })
218            .collect::<anyhow::Result<Vec<_>>>()?;
219
220        // Run images one at a time (tract handles 1-image batches well; for larger batches
221        // we iterate and concatenate postprocessing results).
222        let mut plate_data_all: Vec<f32> = Vec::new();
223        let mut region_data_all: Vec<f32> = Vec::new();
224        let n = imgs.len();
225
226        for img in &imgs {
227            let raw = images_to_batch(std::slice::from_ref(img), &self.config);
228            let h = self.config.img_height as usize;
229            let w = self.config.img_width as usize;
230            let c = self.config.num_channels() as usize;
231
232            let input_tensor: Tensor =
233                tract_ndarray::Array4::<u8>::from_shape_vec((1, h, w, c), raw)
234                    .context("Cannot build input array")?
235                    .into();
236
237            let outputs = self
238                .model
239                .run(tvec![input_tensor.into()])
240                .context("Model run failed")?;
241
242            // Plate output (always present)
243            let plate_out = outputs
244                .get(self.plate_output_idx)
245                .context("Missing plate output")?;
246            let plate_view = plate_out
247                .to_array_view::<f32>()
248                .context("Cannot read plate output as f32")?;
249            plate_data_all.extend_from_slice(plate_view.as_slice().unwrap());
250
251            // Region output (optional)
252            if self.has_region_head {
253                if let Some(ridx) = self.region_output_idx {
254                    if let Some(region_out) = outputs.get(ridx) {
255                        let region_view = region_out
256                            .to_array_view::<f32>()
257                            .context("Cannot read region output as f32")?;
258                        region_data_all.extend_from_slice(region_view.as_slice().unwrap());
259                    }
260                }
261            }
262        }
263
264        postprocess_output(
265            &plate_data_all,
266            n,
267            self.config.max_plate_slots,
268            &self.config.alphabet,
269            self.config.pad_char,
270            remove_pad_char,
271            return_confidence,
272            if self.has_region_head && !region_data_all.is_empty() {
273                Some(&region_data_all)
274            } else {
275                None
276            },
277            if self.has_region_head {
278                self.config.plate_regions.as_deref()
279            } else {
280                None
281            },
282        )
283    }
284
285    /// Convenience wrapper for a single image.
286    pub fn run_one(
287        &self,
288        input: PlateInput<'_>,
289        return_confidence: bool,
290        remove_pad_char: bool,
291    ) -> anyhow::Result<PlatePrediction> {
292        let mut results = self.run(&[input], return_confidence, remove_pad_char)?;
293        if results.len() != 1 {
294            bail!("Expected exactly 1 result, got {}", results.len());
295        }
296        Ok(results.remove(0))
297    }
298
299    // -----------------------------------------------------------------------
300    // Benchmark
301    // -----------------------------------------------------------------------
302
303    /// Run a simple throughput benchmark and print the results to stdout.
304    ///
305    /// * `n_iter`   – number of timed iterations
306    /// * `batch_size` – images per "batch" (each is run individually due to tract's fixed shape)
307    /// * `warmup`   – number of warm-up iterations (not counted)
308    /// * `include_processing` – if `true`, preprocessing / postprocessing are included in the timing.
309    pub fn benchmark(
310        &self,
311        n_iter: usize,
312        batch_size: usize,
313        warmup: usize,
314        include_processing: bool,
315    ) -> anyhow::Result<()> {
316        use image::{DynamicImage, ImageBuffer, Luma, Rgb};
317
318        let h = self.config.img_height;
319        let w = self.config.img_width;
320        let c = self.config.num_channels();
321
322        let raw_pixels: Vec<u8> = (0..(h as usize * w as usize * c as usize))
323            .map(|i| (i % 256) as u8)
324            .collect();
325
326        let make_image = || -> DynamicImage {
327            if c == 1 {
328                DynamicImage::ImageLuma8(
329                    ImageBuffer::<Luma<u8>, _>::from_raw(w, h, raw_pixels.clone()).unwrap(),
330                )
331            } else {
332                DynamicImage::ImageRgb8(
333                    ImageBuffer::<Rgb<u8>, _>::from_raw(w, h, raw_pixels.clone()).unwrap(),
334                )
335            }
336        };
337
338        let run_once = || -> anyhow::Result<()> {
339            let img = make_image();
340            if include_processing {
341                let inputs = vec![PlateInput::Image(img)];
342                self.run(&inputs, false, true)?;
343            } else {
344                let raw = images_to_batch(std::slice::from_ref(&img), &self.config);
345                let input_tensor: Tensor =
346                    tract_ndarray::Array4::<u8>::from_shape_vec(
347                        (1, h as usize, w as usize, c as usize),
348                        raw,
349                    )
350                    .unwrap()
351                    .into();
352                self.model
353                    .run(tvec![input_tensor.into()])
354                    .context("benchmark run")?;
355            }
356            Ok(())
357        };
358
359        // Warm-up
360        for _ in 0..warmup {
361            run_once()?;
362        }
363
364        // Timed loop
365        let t0 = Instant::now();
366        for _ in 0..(n_iter * batch_size) {
367            run_once()?;
368        }
369        let elapsed_ms = t0.elapsed().as_secs_f64() * 1_000.0;
370
371        let total_plates = n_iter * batch_size;
372        let avg_ms = if total_plates > 0 {
373            elapsed_ms / n_iter as f64
374        } else {
375            0.0
376        };
377        let pps = if avg_ms > 0.0 {
378            (1_000.0 / avg_ms) * batch_size as f64
379        } else {
380            0.0
381        };
382
383        println!("─────────────────────────────────────────");
384        println!(" Model         : {}", self.model_name);
385        println!(" Batch size    : {batch_size}");
386        println!(" Warm-up iters : {warmup}");
387        println!(" Timed iters   : {n_iter}");
388        println!(" Avg time/batch: {avg_ms:.4} ms");
389        println!(" Plates/second : {pps:.2}");
390        println!("─────────────────────────────────────────");
391
392        Ok(())
393    }
394}
395