1use anyhow::{bail, Context};
4use indicatif::{ProgressBar, ProgressStyle};
5use std::{
6 io::{Read, Write},
7 path::{Path, PathBuf},
8};
9
/// Pre-trained OCR models that can be fetched by this module.
///
/// Each variant maps to a string identifier (see [`OcrModel::as_str`]) and to a
/// pair of download URLs (see [`OcrModel::urls`]).
///
/// The enum carries no data, so it is `Copy`: values can be passed around freely
/// without explicit `.clone()` calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OcrModel {
    /// Identified as `cct-s-v2-global-model`.
    CctSV2Global,
    /// Identified as `cct-xs-v2-global-model`.
    CctXsV2Global,
    /// Identified as `cct-s-v1-global-model`.
    CctSV1Global,
    /// Identified as `cct-xs-v1-global-model`.
    CctXsV1Global,
    /// Identified as `cct-s-relu-v1-global-model`.
    CctSReluV1Global,
    /// Identified as `cct-xs-relu-v1-global-model`.
    CctXsReluV1Global,
    /// Identified as `argentinian-plates-cnn-model`.
    ArgentinianPlatesCnn,
    /// Identified as `argentinian-plates-cnn-synth-model`.
    ArgentinianPlatesCnnSynth,
    /// Identified as `european-plates-mobile-vit-v2-model`.
    EuropeanPlatesMobileVitV2,
    /// Identified as `global-plates-mobile-vit-v2-model`.
    GlobalPlatesMobileVitV2,
}
34
35impl OcrModel {
36 pub fn as_str(&self) -> &'static str {
38 match self {
39 OcrModel::CctSV2Global => "cct-s-v2-global-model",
40 OcrModel::CctXsV2Global => "cct-xs-v2-global-model",
41 OcrModel::CctSV1Global => "cct-s-v1-global-model",
42 OcrModel::CctXsV1Global => "cct-xs-v1-global-model",
43 OcrModel::CctSReluV1Global => "cct-s-relu-v1-global-model",
44 OcrModel::CctXsReluV1Global => "cct-xs-relu-v1-global-model",
45 OcrModel::ArgentinianPlatesCnn => "argentinian-plates-cnn-model",
46 OcrModel::ArgentinianPlatesCnnSynth => "argentinian-plates-cnn-synth-model",
47 OcrModel::EuropeanPlatesMobileVitV2 => "european-plates-mobile-vit-v2-model",
48 OcrModel::GlobalPlatesMobileVitV2 => "global-plates-mobile-vit-v2-model",
49 }
50 }
51
52 pub fn urls(&self) -> (&'static str, &'static str) {
54 match self {
55 OcrModel::CctSV2Global => (
56 concat!(
57 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
58 "arg-plates/cct_s_v2_global.onnx"
59 ),
60 concat!(
61 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
62 "arg-plates/cct_s_v2_global_plate_config.yaml"
63 ),
64 ),
65 OcrModel::CctXsV2Global => (
66 concat!(
67 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
68 "arg-plates/cct_xs_v2_global.onnx"
69 ),
70 concat!(
71 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
72 "arg-plates/cct_xs_v2_global_plate_config.yaml"
73 ),
74 ),
75 OcrModel::CctSV1Global => (
76 concat!(
77 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
78 "arg-plates/cct_s_v1_global.onnx"
79 ),
80 concat!(
81 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
82 "arg-plates/cct_s_v1_global_plate_config.yaml"
83 ),
84 ),
85 OcrModel::CctXsV1Global => (
86 concat!(
87 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
88 "arg-plates/cct_xs_v1_global.onnx"
89 ),
90 concat!(
91 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
92 "arg-plates/cct_xs_v1_global_plate_config.yaml"
93 ),
94 ),
95 OcrModel::CctSReluV1Global => (
96 concat!(
97 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
98 "arg-plates/cct_s_relu_v1_global.onnx"
99 ),
100 concat!(
101 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
102 "arg-plates/cct_s_relu_v1_global_plate_config.yaml"
103 ),
104 ),
105 OcrModel::CctXsReluV1Global => (
106 concat!(
107 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
108 "arg-plates/cct_xs_relu_v1_global.onnx"
109 ),
110 concat!(
111 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
112 "arg-plates/cct_xs_relu_v1_global_plate_config.yaml"
113 ),
114 ),
115 OcrModel::ArgentinianPlatesCnn => (
116 concat!(
117 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
118 "arg-plates/arg_cnn_ocr.onnx"
119 ),
120 concat!(
121 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
122 "arg-plates/arg_cnn_ocr_config.yaml"
123 ),
124 ),
125 OcrModel::ArgentinianPlatesCnnSynth => (
126 concat!(
127 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
128 "arg-plates/arg_cnn_ocr_synth.onnx"
129 ),
130 concat!(
131 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
132 "arg-plates/arg_cnn_ocr_config.yaml"
133 ),
134 ),
135 OcrModel::EuropeanPlatesMobileVitV2 => (
136 concat!(
137 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
138 "arg-plates/european_mobile_vit_v2_ocr.onnx"
139 ),
140 concat!(
141 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
142 "arg-plates/european_mobile_vit_v2_ocr_config.yaml"
143 ),
144 ),
145 OcrModel::GlobalPlatesMobileVitV2 => (
146 concat!(
147 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
148 "arg-plates/global_mobile_vit_v2_ocr.onnx"
149 ),
150 concat!(
151 "https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
152 "arg-plates/global_mobile_vit_v2_ocr_config.yaml"
153 ),
154 ),
155 }
156 }
157
158 pub fn from_str(s: &str) -> Option<Self> {
160 match s {
161 "cct-s-v2-global-model" => Some(OcrModel::CctSV2Global),
162 "cct-xs-v2-global-model" => Some(OcrModel::CctXsV2Global),
163 "cct-s-v1-global-model" => Some(OcrModel::CctSV1Global),
164 "cct-xs-v1-global-model" => Some(OcrModel::CctXsV1Global),
165 "cct-s-relu-v1-global-model" => Some(OcrModel::CctSReluV1Global),
166 "cct-xs-relu-v1-global-model" => Some(OcrModel::CctXsReluV1Global),
167 "argentinian-plates-cnn-model" => Some(OcrModel::ArgentinianPlatesCnn),
168 "argentinian-plates-cnn-synth-model" => Some(OcrModel::ArgentinianPlatesCnnSynth),
169 "european-plates-mobile-vit-v2-model" => Some(OcrModel::EuropeanPlatesMobileVitV2),
170 "global-plates-mobile-vit-v2-model" => Some(OcrModel::GlobalPlatesMobileVitV2),
171 _ => None,
172 }
173 }
174}
175
176pub fn default_cache_dir() -> PathBuf {
178 dirs::cache_dir()
179 .unwrap_or_else(|| PathBuf::from(".cache"))
180 .join("fast-plate-ocr")
181}
182
183fn download_file(url: &str, dest: &Path) -> anyhow::Result<()> {
185 let mut response = ureq::get(url)
186 .call()
187 .with_context(|| format!("HTTP request failed for {url}"))?;
188
189 let content_length = response
190 .headers()
191 .get("content-length")
192 .and_then(|v| v.to_str().ok())
193 .and_then(|v| v.parse::<u64>().ok())
194 .unwrap_or(0);
195
196 let file_name = dest
197 .file_name()
198 .map(|n| n.to_string_lossy().into_owned())
199 .unwrap_or_else(|| url.to_owned());
200
201 let pb = ProgressBar::new(content_length);
202 pb.set_style(
203 ProgressStyle::with_template(
204 "{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
205 )
206 .unwrap()
207 .progress_chars("##-"),
208 );
209 pb.set_message(format!("Downloading {file_name}"));
210
211 let tmp = dest.with_extension("tmp");
213 {
214 let mut file =
215 std::fs::File::create(&tmp).with_context(|| format!("Cannot create {}", tmp.display()))?;
216
217 let mut buf = [0u8; 65_536];
218 let body = response.body_mut();
219 loop {
220 let n = body
221 .as_reader()
222 .read(&mut buf)
223 .context("Error reading HTTP body")?;
224 if n == 0 {
225 break;
226 }
227 file.write_all(&buf[..n]).context("Error writing file")?;
228 pb.inc(n as u64);
229 }
230 }
231 pb.finish_with_message(format!("Saved {file_name}"));
232
233 std::fs::rename(&tmp, dest)
234 .with_context(|| format!("Cannot rename {} → {}", tmp.display(), dest.display()))?;
235
236 Ok(())
237}
238
239pub fn download_model(
244 model: &OcrModel,
245 save_dir: Option<&Path>,
246 force_download: bool,
247) -> anyhow::Result<(PathBuf, PathBuf)> {
248 let cache_dir = match save_dir {
249 Some(d) => d.to_path_buf(),
250 None => default_cache_dir().join(model.as_str()),
251 };
252
253 if cache_dir.is_file() {
254 bail!("Expected a directory but found a file: {}", cache_dir.display());
255 }
256
257 std::fs::create_dir_all(&cache_dir)
258 .with_context(|| format!("Cannot create cache dir {}", cache_dir.display()))?;
259
260 let (model_url, config_url) = model.urls();
261
262 let model_filename = cache_dir.join(
263 model_url
264 .rsplit('/')
265 .next()
266 .expect("URL must have a path segment"),
267 );
268 let config_filename = cache_dir.join(
269 config_url
270 .rsplit('/')
271 .next()
272 .expect("URL must have a path segment"),
273 );
274
275 if force_download || !model_filename.is_file() {
276 download_file(model_url, &model_filename)?;
277 }
278 if force_download || !config_filename.is_file() {
279 download_file(config_url, &config_filename)?;
280 }
281
282 Ok((model_filename, config_filename))
283}