use anyhow::{bail, Context};
use indicatif::{ProgressBar, ProgressStyle};
use std::{
io::{Read, Write},
path::{Path, PathBuf},
};
/// The OCR models this module knows how to locate and download.
///
/// Every variant has a string identifier (see `OcrModel::as_str`) and a pair
/// of download URLs for its ONNX file and YAML config (see `OcrModel::urls`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OcrModel {
    /// Identifier: `cct-s-v2-global-model`.
    CctSV2Global,
    /// Identifier: `cct-xs-v2-global-model`.
    CctXsV2Global,
    /// Identifier: `cct-s-v1-global-model`.
    CctSV1Global,
    /// Identifier: `cct-xs-v1-global-model`.
    CctXsV1Global,
    /// Identifier: `cct-s-relu-v1-global-model`.
    CctSReluV1Global,
    /// Identifier: `cct-xs-relu-v1-global-model`.
    CctXsReluV1Global,
    /// Identifier: `argentinian-plates-cnn-model`.
    ArgentinianPlatesCnn,
    /// Identifier: `argentinian-plates-cnn-synth-model`.
    ArgentinianPlatesCnnSynth,
    /// Identifier: `european-plates-mobile-vit-v2-model`.
    EuropeanPlatesMobileVitV2,
    /// Identifier: `global-plates-mobile-vit-v2-model`.
    GlobalPlatesMobileVitV2,
}
impl OcrModel {
pub fn as_str(&self) -> &'static str {
match self {
OcrModel::CctSV2Global => "cct-s-v2-global-model",
OcrModel::CctXsV2Global => "cct-xs-v2-global-model",
OcrModel::CctSV1Global => "cct-s-v1-global-model",
OcrModel::CctXsV1Global => "cct-xs-v1-global-model",
OcrModel::CctSReluV1Global => "cct-s-relu-v1-global-model",
OcrModel::CctXsReluV1Global => "cct-xs-relu-v1-global-model",
OcrModel::ArgentinianPlatesCnn => "argentinian-plates-cnn-model",
OcrModel::ArgentinianPlatesCnnSynth => "argentinian-plates-cnn-synth-model",
OcrModel::EuropeanPlatesMobileVitV2 => "european-plates-mobile-vit-v2-model",
OcrModel::GlobalPlatesMobileVitV2 => "global-plates-mobile-vit-v2-model",
}
}
pub fn urls(&self) -> (&'static str, &'static str) {
match self {
OcrModel::CctSV2Global => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_s_v2_global.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_s_v2_global_plate_config.yaml"
),
),
OcrModel::CctXsV2Global => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_xs_v2_global.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_xs_v2_global_plate_config.yaml"
),
),
OcrModel::CctSV1Global => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_s_v1_global.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_s_v1_global_plate_config.yaml"
),
),
OcrModel::CctXsV1Global => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_xs_v1_global.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_xs_v1_global_plate_config.yaml"
),
),
OcrModel::CctSReluV1Global => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_s_relu_v1_global.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_s_relu_v1_global_plate_config.yaml"
),
),
OcrModel::CctXsReluV1Global => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_xs_relu_v1_global.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/cct_xs_relu_v1_global_plate_config.yaml"
),
),
OcrModel::ArgentinianPlatesCnn => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/arg_cnn_ocr.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/arg_cnn_ocr_config.yaml"
),
),
OcrModel::ArgentinianPlatesCnnSynth => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/arg_cnn_ocr_synth.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/arg_cnn_ocr_config.yaml"
),
),
OcrModel::EuropeanPlatesMobileVitV2 => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/european_mobile_vit_v2_ocr.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/european_mobile_vit_v2_ocr_config.yaml"
),
),
OcrModel::GlobalPlatesMobileVitV2 => (
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/global_mobile_vit_v2_ocr.onnx"
),
concat!(
"https://github.com/ankandrew/cnn-ocr-lp/releases/download/",
"arg-plates/global_mobile_vit_v2_ocr_config.yaml"
),
),
}
}
pub fn from_str(s: &str) -> Option<Self> {
match s {
"cct-s-v2-global-model" => Some(OcrModel::CctSV2Global),
"cct-xs-v2-global-model" => Some(OcrModel::CctXsV2Global),
"cct-s-v1-global-model" => Some(OcrModel::CctSV1Global),
"cct-xs-v1-global-model" => Some(OcrModel::CctXsV1Global),
"cct-s-relu-v1-global-model" => Some(OcrModel::CctSReluV1Global),
"cct-xs-relu-v1-global-model" => Some(OcrModel::CctXsReluV1Global),
"argentinian-plates-cnn-model" => Some(OcrModel::ArgentinianPlatesCnn),
"argentinian-plates-cnn-synth-model" => Some(OcrModel::ArgentinianPlatesCnnSynth),
"european-plates-mobile-vit-v2-model" => Some(OcrModel::EuropeanPlatesMobileVitV2),
"global-plates-mobile-vit-v2-model" => Some(OcrModel::GlobalPlatesMobileVitV2),
_ => None,
}
}
}
/// Returns the default directory under which downloaded models are cached.
///
/// This is `fast-plate-ocr` inside the directory reported by
/// `dirs::cache_dir()`, or inside a relative `.cache` directory when
/// `dirs::cache_dir()` returns `None`.
pub fn default_cache_dir() -> PathBuf {
    let mut dir = match dirs::cache_dir() {
        Some(base) => base,
        None => PathBuf::from(".cache"),
    };
    dir.push("fast-plate-ocr");
    dir
}
/// Downloads `url` to `dest`, reporting progress on a terminal progress bar.
///
/// The body is streamed into a sibling temporary file (`dest` with its
/// extension replaced by `tmp`) and renamed onto `dest` only after the whole
/// body has been written, so this function never leaves a partially written
/// `dest`. If streaming or the rename fails, the temporary file is removed
/// (best-effort) before the error is returned.
///
/// # Errors
///
/// Returns an error if the HTTP request fails, the temporary file cannot be
/// created, reading the body or writing the file fails, or the final rename
/// fails.
fn download_file(url: &str, dest: &Path) -> anyhow::Result<()> {
    let mut response = ureq::get(url)
        .call()
        .with_context(|| format!("HTTP request failed for {url}"))?;

    // A missing or unparsable Content-Length falls back to a bar length of 0.
    let content_length = response
        .headers()
        .get("content-length")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse::<u64>().ok())
        .unwrap_or(0);

    // Shown in progress messages; falls back to the URL if `dest` has no file name.
    let file_name = dest
        .file_name()
        .map(|n| n.to_string_lossy().into_owned())
        .unwrap_or_else(|| url.to_owned());

    let pb = ProgressBar::new(content_length);
    pb.set_style(
        ProgressStyle::with_template(
            "{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
        )
        .expect("progress bar template is a valid constant")
        .progress_chars("##-"),
    );
    pb.set_message(format!("Downloading {file_name}"));

    let tmp = dest.with_extension("tmp");
    let file =
        std::fs::File::create(&tmp).with_context(|| format!("Cannot create {}", tmp.display()))?;

    // Everything after the temp file exists runs in one fallible unit so that a
    // failure at any step can be followed by a single cleanup below.
    let outcome = (|| -> anyhow::Result<()> {
        let mut file = file;
        // Obtain the body reader once: building a new reader per chunk would
        // discard any state the reader keeps between reads.
        let mut reader = response.body_mut().as_reader();
        let mut buf = [0u8; 65_536];
        loop {
            let n = reader
                .read(&mut buf)
                .context("Error reading HTTP body")?;
            if n == 0 {
                break;
            }
            file.write_all(&buf[..n]).context("Error writing file")?;
            pb.inc(n as u64);
        }
        // Close the handle before the rename.
        drop(file);

        pb.finish_with_message(format!("Saved {file_name}"));
        std::fs::rename(&tmp, dest)
            .with_context(|| format!("Cannot rename {} → {}", tmp.display(), dest.display()))?;
        Ok(())
    })();

    if outcome.is_err() {
        // Best-effort: do not leave a partial download behind. The original
        // error is what the caller needs, so a cleanup failure is ignored.
        let _ = std::fs::remove_file(&tmp);
    }
    outcome
}
pub fn download_model(
model: &OcrModel,
save_dir: Option<&Path>,
force_download: bool,
) -> anyhow::Result<(PathBuf, PathBuf)> {
let cache_dir = match save_dir {
Some(d) => d.to_path_buf(),
None => default_cache_dir().join(model.as_str()),
};
if cache_dir.is_file() {
bail!("Expected a directory but found a file: {}", cache_dir.display());
}
std::fs::create_dir_all(&cache_dir)
.with_context(|| format!("Cannot create cache dir {}", cache_dir.display()))?;
let (model_url, config_url) = model.urls();
let model_filename = cache_dir.join(
model_url
.rsplit('/')
.next()
.expect("URL must have a path segment"),
);
let config_filename = cache_dir.join(
config_url
.rsplit('/')
.next()
.expect("URL must have a path segment"),
);
if force_download || !model_filename.is_file() {
download_file(model_url, &model_filename)?;
}
if force_download || !config_filename.is_file() {
download_file(config_url, &config_filename)?;
}
Ok((model_filename, config_filename))
}