1use crate::{
4 config::PlateConfig,
5 hub::{download_model, OcrModel},
6 process::{images_to_batch, postprocess_output, read_and_resize_plate_image, PlatePrediction},
7};
8use anyhow::{bail, Context};
9use image::DynamicImage;
10use std::{
11 path::Path,
12 time::Instant,
13};
14use tract_onnx::prelude::*;
15
/// Input accepted by the recognizer: either a path to an image file on disk
/// or an already-decoded in-memory image.
pub enum PlateInput<'a> {
    /// Borrowed filesystem path; the image is read and resized before inference.
    Path(&'a Path),
    /// Owned decoded image; resized (no disk I/O) before inference.
    Image(DynamicImage),
}
25
26impl<'a> From<&'a str> for PlateInput<'a> {
27 fn from(s: &'a str) -> Self {
28 PlateInput::Path(Path::new(s))
29 }
30}
31
32impl<'a> From<&'a Path> for PlateInput<'a> {
33 fn from(p: &'a Path) -> Self {
34 PlateInput::Path(p)
35 }
36}
37
38impl From<DynamicImage> for PlateInput<'_> {
39 fn from(img: DynamicImage) -> Self {
40 PlateInput::Image(img)
41 }
42}
43
/// Optimized, runnable tract ONNX execution plan (typed model ready for inference).
type OnnxModel = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
49
/// ONNX-backed license-plate OCR engine: holds the runnable model plan plus
/// the pre/post-processing configuration and output wiring.
pub struct LicensePlateRecognizer {
    // Optimized tract plan; input is pinned to a single NHWC u8 image.
    model: OnnxModel,
    /// Pre/post-processing configuration loaded from the plate YAML config.
    pub config: PlateConfig,
    /// Human-readable model identifier (hub model name or ONNX file stem).
    pub model_name: String,
    // Index of the model output carrying per-slot character logits (always 0 here).
    plate_output_idx: usize,
    // Index of the optional second output carrying region logits, if the model has one.
    region_output_idx: Option<usize>,
    // True only when BOTH the model exposes a region output AND the config
    // declares a region list; otherwise region predictions are disabled.
    has_region_head: bool,
}
65
66impl LicensePlateRecognizer {
67 pub fn from_files(
73 onnx_model_path: impl AsRef<Path>,
74 plate_config_path: impl AsRef<Path>,
75 ) -> anyhow::Result<Self> {
76 let onnx_path = onnx_model_path.as_ref();
77 let cfg_path = plate_config_path.as_ref();
78
79 if !onnx_path.exists() {
80 bail!("ONNX model not found: {}", onnx_path.display());
81 }
82 if !cfg_path.exists() {
83 bail!("Plate config not found: {}", cfg_path.display());
84 }
85
86 let model_name = onnx_path
87 .file_stem()
88 .map(|s| s.to_string_lossy().into_owned())
89 .unwrap_or_else(|| "custom".to_owned());
90
91 let config = PlateConfig::from_yaml(cfg_path)?;
92
93 let model = Self::load_model(onnx_path, &config)?;
94
95 Self::from_model_and_config(model, config, model_name)
96 }
97
98 pub fn from_hub(model: OcrModel, force_download: bool) -> anyhow::Result<Self> {
100 let model_name = model.as_str().to_owned();
101 let (onnx_path, cfg_path) = download_model(&model, None, force_download)?;
102 let mut recognizer = Self::from_files(onnx_path, cfg_path)?;
103 recognizer.model_name = model_name;
104 Ok(recognizer)
105 }
106
107 pub fn from_hub_to_dir(
109 model: OcrModel,
110 save_dir: &Path,
111 force_download: bool,
112 ) -> anyhow::Result<Self> {
113 let model_name = model.as_str().to_owned();
114 let (onnx_path, cfg_path) = download_model(&model, Some(save_dir), force_download)?;
115 let mut recognizer = Self::from_files(onnx_path, cfg_path)?;
116 recognizer.model_name = model_name;
117 Ok(recognizer)
118 }
119
120 fn load_model(onnx_path: &Path, config: &PlateConfig) -> anyhow::Result<OnnxModel> {
125 let h = config.img_height;
126 let w = config.img_width;
127 let c = config.num_channels();
128
129 let plan = tract_onnx::onnx()
130 .model_for_path(onnx_path)
131 .context("Cannot parse ONNX model")?
132 .with_input_fact(
133 0,
134 InferenceFact::dt_shape(u8::datum_type(), tvec![1usize, h as usize, w as usize, c as usize]),
135 )
136 .context("Cannot set input fact")?
137 .into_optimized()
138 .context("Cannot optimise ONNX model")?
139 .into_runnable()
140 .context("Cannot make model runnable")?;
141
142 Ok(plan)
143 }
144
145 fn from_model_and_config(
146 model: OnnxModel,
147 config: PlateConfig,
148 model_name: String,
149 ) -> anyhow::Result<Self> {
150 let num_outputs = model.model().output_outlets()?.len();
155 let plate_output_idx = 0;
156 let region_output_idx = if num_outputs > 1 { Some(1) } else { None };
157
158 let has_region_head = region_output_idx.is_some() && config.has_region_recognition();
159
160 if region_output_idx.is_none() && config.has_region_recognition() {
161 eprintln!(
162 "Warning: plate config declares regions but the model has only one output. \
163 Region predictions will be disabled."
164 );
165 }
166 if region_output_idx.is_some() && !config.has_region_recognition() {
167 eprintln!(
168 "Warning: model has a second output but the plate config has no region list. \
169 Region predictions will be disabled."
170 );
171 }
172
173 Ok(LicensePlateRecognizer {
174 model,
175 config,
176 model_name,
177 plate_output_idx,
178 region_output_idx,
179 has_region_head,
180 })
181 }
182
183 pub fn run(
192 &self,
193 inputs: &[PlateInput<'_>],
194 return_confidence: bool,
195 remove_pad_char: bool,
196 ) -> anyhow::Result<Vec<PlatePrediction>> {
197 if inputs.is_empty() {
198 return Ok(vec![]);
199 }
200
201 let imgs: Vec<DynamicImage> = inputs
203 .iter()
204 .map(|inp| match inp {
205 PlateInput::Path(p) => read_and_resize_plate_image(p, &self.config),
206 PlateInput::Image(img) => {
207 crate::process::resize_image(
208 img.clone(),
209 self.config.img_height,
210 self.config.img_width,
211 &self.config.image_color_mode,
212 self.config.keep_aspect_ratio,
213 &self.config.interpolation,
214 &self.config.padding_color,
215 )
216 }
217 })
218 .collect::<anyhow::Result<Vec<_>>>()?;
219
220 let mut plate_data_all: Vec<f32> = Vec::new();
223 let mut region_data_all: Vec<f32> = Vec::new();
224 let n = imgs.len();
225
226 for img in &imgs {
227 let raw = images_to_batch(std::slice::from_ref(img), &self.config);
228 let h = self.config.img_height as usize;
229 let w = self.config.img_width as usize;
230 let c = self.config.num_channels() as usize;
231
232 let input_tensor: Tensor =
233 tract_ndarray::Array4::<u8>::from_shape_vec((1, h, w, c), raw)
234 .context("Cannot build input array")?
235 .into();
236
237 let outputs = self
238 .model
239 .run(tvec![input_tensor.into()])
240 .context("Model run failed")?;
241
242 let plate_out = outputs
244 .get(self.plate_output_idx)
245 .context("Missing plate output")?;
246 let plate_view = plate_out
247 .to_array_view::<f32>()
248 .context("Cannot read plate output as f32")?;
249 plate_data_all.extend_from_slice(plate_view.as_slice().unwrap());
250
251 if self.has_region_head {
253 if let Some(ridx) = self.region_output_idx {
254 if let Some(region_out) = outputs.get(ridx) {
255 let region_view = region_out
256 .to_array_view::<f32>()
257 .context("Cannot read region output as f32")?;
258 region_data_all.extend_from_slice(region_view.as_slice().unwrap());
259 }
260 }
261 }
262 }
263
264 postprocess_output(
265 &plate_data_all,
266 n,
267 self.config.max_plate_slots,
268 &self.config.alphabet,
269 self.config.pad_char,
270 remove_pad_char,
271 return_confidence,
272 if self.has_region_head && !region_data_all.is_empty() {
273 Some(®ion_data_all)
274 } else {
275 None
276 },
277 if self.has_region_head {
278 self.config.plate_regions.as_deref()
279 } else {
280 None
281 },
282 )
283 }
284
285 pub fn run_one(
287 &self,
288 input: PlateInput<'_>,
289 return_confidence: bool,
290 remove_pad_char: bool,
291 ) -> anyhow::Result<PlatePrediction> {
292 let mut results = self.run(&[input], return_confidence, remove_pad_char)?;
293 if results.len() != 1 {
294 bail!("Expected exactly 1 result, got {}", results.len());
295 }
296 Ok(results.remove(0))
297 }
298
    /// Benchmarks the recognizer on synthetic images and prints a summary.
    ///
    /// * `n_iter` — number of timed batches.
    /// * `batch_size` — plates per batch; the plan runs one image at a time,
    ///   so a batch is simulated as `batch_size` sequential single runs.
    /// * `warmup` — untimed single runs executed before timing starts.
    ///   NOTE(review): warm-up is NOT multiplied by `batch_size`, unlike the
    ///   timed loop — confirm whether that asymmetry is intended.
    /// * `include_processing` — when true, times the full `run` pipeline
    ///   (resize, tensor build, inference, postprocess); when false, times
    ///   only tensor build plus raw model execution.
    ///
    /// # Errors
    /// Propagates any failure from the underlying `run`/model execution.
    pub fn benchmark(
        &self,
        n_iter: usize,
        batch_size: usize,
        warmup: usize,
        include_processing: bool,
    ) -> anyhow::Result<()> {
        use image::{DynamicImage, ImageBuffer, Luma, Rgb};

        let h = self.config.img_height;
        let w = self.config.img_width;
        let c = self.config.num_channels();

        // Deterministic synthetic pixel data (0, 1, ..., 255, 0, 1, ...).
        let raw_pixels: Vec<u8> = (0..(h as usize * w as usize * c as usize))
            .map(|i| (i % 256) as u8)
            .collect();

        // Builds a grayscale or RGB image matching the configured channel count.
        let make_image = || -> DynamicImage {
            if c == 1 {
                DynamicImage::ImageLuma8(
                    // unwrap is safe: raw_pixels length is exactly w*h*c.
                    ImageBuffer::<Luma<u8>, _>::from_raw(w, h, raw_pixels.clone()).unwrap(),
                )
            } else {
                DynamicImage::ImageRgb8(
                    ImageBuffer::<Rgb<u8>, _>::from_raw(w, h, raw_pixels.clone()).unwrap(),
                )
            }
        };

        // Executes one single-image inference, with or without pre/post-processing.
        let run_once = || -> anyhow::Result<()> {
            let img = make_image();
            if include_processing {
                let inputs = vec![PlateInput::Image(img)];
                self.run(&inputs, false, true)?;
            } else {
                let raw = images_to_batch(std::slice::from_ref(&img), &self.config);
                let input_tensor: Tensor =
                    tract_ndarray::Array4::<u8>::from_shape_vec(
                        (1, h as usize, w as usize, c as usize),
                        raw,
                    )
                    .unwrap()
                    .into();
                self.model
                    .run(tvec![input_tensor.into()])
                    .context("benchmark run")?;
            }
            Ok(())
        };

        // Untimed warm-up runs (populate caches, trigger lazy initialization).
        for _ in 0..warmup {
            run_once()?;
        }

        let t0 = Instant::now();
        for _ in 0..(n_iter * batch_size) {
            run_once()?;
        }
        let elapsed_ms = t0.elapsed().as_secs_f64() * 1_000.0;

        let total_plates = n_iter * batch_size;
        // Average wall time per simulated batch (i.e. per batch_size runs).
        let avg_ms = if total_plates > 0 {
            elapsed_ms / n_iter as f64
        } else {
            0.0
        };
        // Throughput in plates per second derived from the per-batch average.
        let pps = if avg_ms > 0.0 {
            (1_000.0 / avg_ms) * batch_size as f64
        } else {
            0.0
        };

        println!("─────────────────────────────────────────");
        println!(" Model : {}", self.model_name);
        println!(" Batch size : {batch_size}");
        println!(" Warm-up iters : {warmup}");
        println!(" Timed iters : {n_iter}");
        println!(" Avg time/batch: {avg_ms:.4} ms");
        println!(" Plates/second : {pps:.2}");
        println!("─────────────────────────────────────────");

        Ok(())
    }
394}
395