Skip to main content

fpo_rust/
hub.rs

1//! Model hub: known ONNX models with their download URLs and local caching.
2
3use anyhow::{bail, Context};
4use indicatif::{ProgressBar, ProgressStyle};
5use std::{
6    io::{Read, Write},
7    path::{Path, PathBuf},
8};
9
/// The set of pre-trained OCR models available for download.
///
/// Each variant maps to a cache-directory identifier via [`OcrModel::as_str`] and to
/// its `(onnx_url, config_url)` pair via [`OcrModel::urls`]. The enum only has unit
/// variants, so it is `Copy` and can be passed around by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OcrModel {
    /// Compact Convolutional Transformer – Small, v2, global plates.
    CctSV2Global,
    /// Compact Convolutional Transformer – XSmall, v2, global plates.
    CctXsV2Global,
    /// Compact Convolutional Transformer – Small, v1, global plates.
    CctSV1Global,
    /// Compact Convolutional Transformer – XSmall, v1, global plates.
    CctXsV1Global,
    /// Compact Convolutional Transformer – Small, ReLU, v1, global plates.
    CctSReluV1Global,
    /// Compact Convolutional Transformer – XSmall, ReLU, v1, global plates.
    CctXsReluV1Global,
    /// Argentinian plates CNN model.
    ArgentinianPlatesCnn,
    /// Argentinian plates CNN model trained with synthetic data.
    ArgentinianPlatesCnnSynth,
    /// European plates MobileVIT-v2 model.
    EuropeanPlatesMobileVitV2,
    /// Global plates (65+ countries) MobileVIT-v2 model.
    GlobalPlatesMobileVitV2,
}
34
35impl OcrModel {
36    /// Return the string identifier used to name cache directories.
37    pub fn as_str(&self) -> &'static str {
38        match self {
39            OcrModel::CctSV2Global => "cct-s-v2-global-model",
40            OcrModel::CctXsV2Global => "cct-xs-v2-global-model",
41            OcrModel::CctSV1Global => "cct-s-v1-global-model",
42            OcrModel::CctXsV1Global => "cct-xs-v1-global-model",
43            OcrModel::CctSReluV1Global => "cct-s-relu-v1-global-model",
44            OcrModel::CctXsReluV1Global => "cct-xs-relu-v1-global-model",
45            OcrModel::ArgentinianPlatesCnn => "argentinian-plates-cnn-model",
46            OcrModel::ArgentinianPlatesCnnSynth => "argentinian-plates-cnn-synth-model",
47            OcrModel::EuropeanPlatesMobileVitV2 => "european-plates-mobile-vit-v2-model",
48            OcrModel::GlobalPlatesMobileVitV2 => "global-plates-mobile-vit-v2-model",
49        }
50    }
51
52    /// Return `(onnx_url, config_url)` for this model.
53    pub fn urls(&self) -> (&'static str, &'static str) {
54        match self {
55            OcrModel::CctSV2Global => (
56                concat!(
57                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
58                    "arg-plates/cct_s_v2_global.onnx"
59                ),
60                concat!(
61                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
62                    "arg-plates/cct_s_v2_global_plate_config.yaml"
63                ),
64            ),
65            OcrModel::CctXsV2Global => (
66                concat!(
67                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
68                    "arg-plates/cct_xs_v2_global.onnx"
69                ),
70                concat!(
71                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
72                    "arg-plates/cct_xs_v2_global_plate_config.yaml"
73                ),
74            ),
75            OcrModel::CctSV1Global => (
76                concat!(
77                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
78                    "arg-plates/cct_s_v1_global.onnx"
79                ),
80                concat!(
81                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
82                    "arg-plates/cct_s_v1_global_plate_config.yaml"
83                ),
84            ),
85            OcrModel::CctXsV1Global => (
86                concat!(
87                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
88                    "arg-plates/cct_xs_v1_global.onnx"
89                ),
90                concat!(
91                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
92                    "arg-plates/cct_xs_v1_global_plate_config.yaml"
93                ),
94            ),
95            OcrModel::CctSReluV1Global => (
96                concat!(
97                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
98                    "arg-plates/cct_s_relu_v1_global.onnx"
99                ),
100                concat!(
101                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
102                    "arg-plates/cct_s_relu_v1_global_plate_config.yaml"
103                ),
104            ),
105            OcrModel::CctXsReluV1Global => (
106                concat!(
107                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
108                    "arg-plates/cct_xs_relu_v1_global.onnx"
109                ),
110                concat!(
111                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
112                    "arg-plates/cct_xs_relu_v1_global_plate_config.yaml"
113                ),
114            ),
115            OcrModel::ArgentinianPlatesCnn => (
116                concat!(
117                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
118                    "arg-plates/arg_cnn_ocr.onnx"
119                ),
120                concat!(
121                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
122                    "arg-plates/arg_cnn_ocr_config.yaml"
123                ),
124            ),
125            OcrModel::ArgentinianPlatesCnnSynth => (
126                concat!(
127                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
128                    "arg-plates/arg_cnn_ocr_synth.onnx"
129                ),
130                concat!(
131                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
132                    "arg-plates/arg_cnn_ocr_config.yaml"
133                ),
134            ),
135            OcrModel::EuropeanPlatesMobileVitV2 => (
136                concat!(
137                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
138                    "arg-plates/european_mobile_vit_v2_ocr.onnx"
139                ),
140                concat!(
141                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
142                    "arg-plates/european_mobile_vit_v2_ocr_config.yaml"
143                ),
144            ),
145            OcrModel::GlobalPlatesMobileVitV2 => (
146                concat!(
147                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
148                    "arg-plates/global_mobile_vit_v2_ocr.onnx"
149                ),
150                concat!(
151                    "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
152                    "arg-plates/global_mobile_vit_v2_ocr_config.yaml"
153                ),
154            ),
155        }
156    }
157
158    /// Parse a model from its string identifier.
159    pub fn from_str(s: &str) -> Option<Self> {
160        match s {
161            "cct-s-v2-global-model" => Some(OcrModel::CctSV2Global),
162            "cct-xs-v2-global-model" => Some(OcrModel::CctXsV2Global),
163            "cct-s-v1-global-model" => Some(OcrModel::CctSV1Global),
164            "cct-xs-v1-global-model" => Some(OcrModel::CctXsV1Global),
165            "cct-s-relu-v1-global-model" => Some(OcrModel::CctSReluV1Global),
166            "cct-xs-relu-v1-global-model" => Some(OcrModel::CctXsReluV1Global),
167            "argentinian-plates-cnn-model" => Some(OcrModel::ArgentinianPlatesCnn),
168            "argentinian-plates-cnn-synth-model" => Some(OcrModel::ArgentinianPlatesCnnSynth),
169            "european-plates-mobile-vit-v2-model" => Some(OcrModel::EuropeanPlatesMobileVitV2),
170            "global-plates-mobile-vit-v2-model" => Some(OcrModel::GlobalPlatesMobileVitV2),
171            _ => None,
172        }
173    }
174}
175
176/// Default cache directory: `~/.cache/fast-plate-ocr/`.
177pub fn default_cache_dir() -> PathBuf {
178    dirs::cache_dir()
179        .unwrap_or_else(|| PathBuf::from(".cache"))
180        .join("fast-plate-ocr")
181}
182
183/// Download a single file from `url` to `dest`, showing a progress bar.
184fn download_file(url: &str, dest: &Path) -> anyhow::Result<()> {
185    let mut response = ureq::get(url)
186        .call()
187        .with_context(|| format!("HTTP request failed for {url}"))?;
188
189    let content_length = response
190        .headers()
191        .get("content-length")
192        .and_then(|v| v.to_str().ok())
193        .and_then(|v| v.parse::<u64>().ok())
194        .unwrap_or(0);
195
196    let file_name = dest
197        .file_name()
198        .map(|n| n.to_string_lossy().into_owned())
199        .unwrap_or_else(|| url.to_owned());
200
201    let pb = ProgressBar::new(content_length);
202    pb.set_style(
203        ProgressStyle::with_template(
204            "{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
205        )
206        .unwrap()
207        .progress_chars("##-"),
208    );
209    pb.set_message(format!("Downloading {file_name}"));
210
211    // Write to a temporary sibling file then rename (atomic-ish).
212    let tmp = dest.with_extension("tmp");
213    {
214        let mut file =
215            std::fs::File::create(&tmp).with_context(|| format!("Cannot create {}", tmp.display()))?;
216
217        let mut buf = [0u8; 65_536];
218        let body = response.body_mut();
219        loop {
220            let n = body
221                .as_reader()
222                .read(&mut buf)
223                .context("Error reading HTTP body")?;
224            if n == 0 {
225                break;
226            }
227            file.write_all(&buf[..n]).context("Error writing file")?;
228            pb.inc(n as u64);
229        }
230    }
231    pb.finish_with_message(format!("Saved {file_name}"));
232
233    std::fs::rename(&tmp, dest)
234        .with_context(|| format!("Cannot rename {} → {}", tmp.display(), dest.display()))?;
235
236    Ok(())
237}
238
239/// Download an OCR model from the hub and return `(onnx_path, config_path)`.
240///
241/// Files are cached in `save_dir` (defaults to `~/.cache/fast-plate-ocr/<model_name>/`).
242/// Set `force_download = true` to re-download even if the files already exist.
243pub fn download_model(
244    model: &OcrModel,
245    save_dir: Option<&Path>,
246    force_download: bool,
247) -> anyhow::Result<(PathBuf, PathBuf)> {
248    let cache_dir = match save_dir {
249        Some(d) => d.to_path_buf(),
250        None => default_cache_dir().join(model.as_str()),
251    };
252
253    if cache_dir.is_file() {
254        bail!("Expected a directory but found a file: {}", cache_dir.display());
255    }
256
257    std::fs::create_dir_all(&cache_dir)
258        .with_context(|| format!("Cannot create cache dir {}", cache_dir.display()))?;
259
260    let (model_url, config_url) = model.urls();
261
262    let model_filename = cache_dir.join(
263        model_url
264            .rsplit('/')
265            .next()
266            .expect("URL must have a path segment"),
267    );
268    let config_filename = cache_dir.join(
269        config_url
270            .rsplit('/')
271            .next()
272            .expect("URL must have a path segment"),
273    );
274
275    if force_download || !model_filename.is_file() {
276        download_file(model_url, &model_filename)?;
277    }
278    if force_download || !config_filename.is_file() {
279        download_file(config_url, &config_filename)?;
280    }
281
282    Ok((model_filename, config_filename))
283}