onnxruntime_ng/
download.rs1#[cfg(feature = "model-fetching")]
14use std::{
15 fs, io,
16 path::{Path, PathBuf},
17 time::Duration,
18};
19
20#[cfg(feature = "model-fetching")]
21use crate::error::{OrtDownloadError, Result};
22
23#[cfg(feature = "model-fetching")]
24use tracing::info;
25
26pub mod language;
27pub mod vision;
28
29#[derive(Debug, Clone)]
36pub enum AvailableOnnxModel {
37 Vision(vision::Vision),
39 Language(language::Language),
41}
42
/// Implemented by every model description that knows where its file can be
/// fetched from.
trait ModelUrl {
    /// Returns the URL the model file is downloaded from.
    ///
    /// The string has a `'static` lifetime, so it is never built at runtime.
    fn fetch_url(&self) -> &'static str;
}
46
47impl ModelUrl for AvailableOnnxModel {
48 fn fetch_url(&self) -> &'static str {
49 match self {
50 AvailableOnnxModel::Vision(model) => model.fetch_url(),
51 AvailableOnnxModel::Language(model) => model.fetch_url(),
52 }
53 }
54}
55
56impl AvailableOnnxModel {
57 #[cfg(feature = "model-fetching")]
58 #[tracing::instrument]
59 pub(crate) fn download_to<P>(&self, download_dir: P) -> Result<PathBuf>
60 where
61 P: AsRef<Path> + std::fmt::Debug,
62 {
63 let url = self.fetch_url();
64
65 let model_filename = PathBuf::from(url.split('/').last().unwrap());
66 let model_filepath = download_dir.as_ref().join(model_filename);
67
68 if model_filepath.exists() {
69 info!(
70 model_filepath = format!("{}", model_filepath.display()).as_str(),
71 "File already exists, not re-downloading.",
72 );
73 Ok(model_filepath)
74 } else {
75 info!(
76 model_filepath = format!("{}", model_filepath.display()).as_str(),
77 url = format!("{:?}", url).as_str(),
78 "Downloading file, please wait....",
79 );
80
81 let resp = ureq::get(url)
82 .timeout(Duration::from_secs(180)) .call()
84 .map_err(Box::new)
85 .map_err(OrtDownloadError::UreqError)?;
86
87 assert!(resp.has("Content-Length"));
88 let len = resp
89 .header("Content-Length")
90 .and_then(|s| s.parse::<usize>().ok())
91 .unwrap();
92 info!(len, "Downloading {} bytes...", len);
93
94 let mut reader = resp.into_reader();
95
96 let f = fs::File::create(&model_filepath).unwrap();
97 let mut writer = io::BufWriter::new(f);
98
99 let bytes_io_count =
100 io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
101
102 if bytes_io_count == len as u64 {
103 Ok(model_filepath)
104 } else {
105 Err(OrtDownloadError::CopyError {
106 expected: len as u64,
107 io: bytes_io_count,
108 }
109 .into())
110 }
111 }
112 }
113}