1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#[cfg(feature = "model-fetching")]
use std::{
fs, io,
path::{Path, PathBuf},
time::Duration,
};
#[cfg(feature = "model-fetching")]
use crate::error::{OrtDownloadError, Result};
#[cfg(feature = "model-fetching")]
use tracing::info;
pub mod language;
pub mod vision;
/// A model with a fixed ONNX download location, grouped by model family.
#[derive(Clone, Debug)]
pub enum AvailableOnnxModel {
    /// Wraps a model from the [`vision`] module.
    Vision(vision::Vision),
    /// Wraps a model from the [`language`] module.
    Language(language::Language),
}
/// Associates a model with the URL of its ONNX file.
trait ModelUrl {
    /// Returns the URL the model's ONNX file is fetched from.
    fn fetch_url(&self) -> &'static str;
}
/// Delegates URL lookup to whichever model family the value wraps.
impl ModelUrl for AvailableOnnxModel {
    fn fetch_url(&self) -> &'static str {
        match self {
            Self::Vision(vision_model) => vision_model.fetch_url(),
            Self::Language(language_model) => language_model.fetch_url(),
        }
    }
}
impl AvailableOnnxModel {
    /// Downloads the model's ONNX file into `download_dir` and returns the file's path.
    ///
    /// The file name is the last `/`-separated segment of the model's fetch URL. If a file
    /// with that name already exists in `download_dir`, its path is returned immediately
    /// without re-downloading.
    ///
    /// The response body is first streamed into a sibling `<file name>.part` file. It is
    /// renamed to its final name only after the received byte count matches the server's
    /// `Content-Length`. A failed or truncated download therefore never leaves behind a file
    /// that a later call would treat as an already-downloaded model.
    ///
    /// # Errors
    ///
    /// - `OrtDownloadError::UreqError` if the HTTP request fails (180 s timeout).
    /// - `OrtDownloadError::IoError` if the response has no parseable `Content-Length`
    ///   header, or if creating, writing, flushing, or renaming the file fails.
    /// - `OrtDownloadError::CopyError` if the number of bytes received differs from
    ///   `Content-Length`.
    #[cfg(feature = "model-fetching")]
    #[tracing::instrument]
    pub(crate) fn download_to<P>(&self, download_dir: P) -> Result<PathBuf>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let url = self.fetch_url();
        let model_filename = url
            .rsplit('/')
            .next()
            .expect("str::rsplit always yields at least one item");
        let model_filepath = download_dir.as_ref().join(model_filename);
        if model_filepath.exists() {
            info!(
                model_filepath = %model_filepath.display(),
                "File already exists, not re-downloading.",
            );
            return Ok(model_filepath);
        }

        info!(
            model_filepath = %model_filepath.display(),
            url = %url,
            "Downloading file, please wait....",
        );
        let resp = ureq::get(url)
            .timeout(Duration::from_secs(180))
            .call()
            .map_err(Box::new)
            .map_err(OrtDownloadError::UreqError)?;

        // The expected size is what lets us detect a truncated transfer below. A response
        // without a usable `Content-Length` is rejected with an error instead of panicking.
        let len = resp
            .header("Content-Length")
            .and_then(|s| s.trim().parse::<u64>().ok())
            .ok_or_else(|| {
                OrtDownloadError::IoError(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "HTTP response is missing a valid Content-Length header",
                ))
            })?;
        info!(len, "Downloading {} bytes...", len);

        // Stream into `<name>.part` and rename only once the size check passes. Otherwise a
        // partial file would satisfy the `exists()` check above on the next call.
        let partial_filepath = download_dir
            .as_ref()
            .join(format!("{}.part", model_filename));
        let mut reader = resp.into_reader();
        let outcome = Self::write_to_file(&mut reader, &partial_filepath)
            .map_err(OrtDownloadError::IoError)
            .and_then(|bytes_io_count| {
                if bytes_io_count == len {
                    Ok(())
                } else {
                    Err(OrtDownloadError::CopyError {
                        expected: len,
                        io: bytes_io_count,
                    })
                }
            })
            .and_then(|()| {
                fs::rename(&partial_filepath, &model_filepath).map_err(OrtDownloadError::IoError)
            });

        match outcome {
            Ok(()) => Ok(model_filepath),
            Err(err) => {
                // Best-effort cleanup: the download error is the one worth reporting, so a
                // failure to remove the partial file is deliberately ignored.
                let _ = fs::remove_file(&partial_filepath);
                Err(err.into())
            }
        }
    }

    /// Copies everything from `reader` into the file at `path` and returns the number of
    /// bytes copied. The file is created if missing and truncated if it exists.
    ///
    /// The buffered writer is flushed explicitly because `BufWriter`'s `Drop` ignores flush
    /// errors. Without the flush, a failed final write would go unnoticed.
    #[cfg(feature = "model-fetching")]
    fn write_to_file<R>(reader: &mut R, path: &Path) -> io::Result<u64>
    where
        R: io::Read + ?Sized,
    {
        let mut writer = io::BufWriter::new(fs::File::create(path)?);
        let bytes_copied = io::copy(reader, &mut writer)?;
        io::Write::flush(&mut writer)?;
        Ok(bytes_copied)
    }
}